From 1d24a7879d3f9df62ea2a8f8f6c41dd7819d94e2 Mon Sep 17 00:00:00 2001 From: xzeng Date: Mon, 23 Jan 2023 00:14:49 -0500 Subject: [PATCH] init --- .gitignore | 14 + README.md | 76 +- build_pkg.py | 3 + datasets/data_path.py | 38 + datasets/pointflow_datasets.py | 404 +++++ default_config.py | 450 +++++ demo.py | 45 + env.yaml | 311 ++++ models/adagn.py | 67 + models/dense.py | 80 + models/distributions.py | 37 + models/latent_points_ada.py | 273 +++ models/latent_points_ada_localprior.py | 84 + models/lion.py | 91 + models/pvcnn2.py | 557 ++++++ models/pvcnn2_ada.py | 568 ++++++ models/score_sde/resnet.py | 230 +++ models/shapelatent_modules.py | 54 + models/utils.py | 52 + models/vae_adain.py | 339 ++++ script/compute_score.py | 43 + third_party/ChamferDistancePytorch/.gitignore | 3 + third_party/ChamferDistancePytorch/LICENSE | 21 + third_party/ChamferDistancePytorch/README.md | 104 ++ .../chamfer2D/chamfer2D.cu | 182 ++ .../chamfer2D/chamfer_cuda.cpp | 33 + .../chamfer2D/dist_chamfer_2D.py | 80 + .../ChamferDistancePytorch/chamfer2D/setup.py | 14 + .../chamfer3D/chamfer3D.cu | 196 +++ .../chamfer3D/chamfer_cuda.cpp | 33 + .../chamfer3D/dist_chamfer_3D.py | 133 ++ .../ChamferDistancePytorch/chamfer3D/setup.py | 14 + .../chamfer5D/chamfer5D.cu | 223 +++ .../chamfer5D/chamfer_cuda.cpp | 33 + .../chamfer5D/dist_chamfer_5D.py | 82 + .../ChamferDistancePytorch/chamfer5D/setup.py | 14 + .../chamfer6D/chamfer6D.cu | 237 +++ .../chamfer6D/chamfer_cuda.cpp | 33 + .../chamfer6D/dist_chamfer_6D.py | 82 + .../ChamferDistancePytorch/chamfer6D/setup.py | 14 + .../ChamferDistancePytorch/chamfer_python.py | 44 + third_party/ChamferDistancePytorch/fscore.py | 17 + .../ChamferDistancePytorch/unit_test.py | 69 + third_party/PyTorchEMD/.gitignore | 5 + third_party/PyTorchEMD/README.md | 34 + third_party/PyTorchEMD/__init__.py | 0 third_party/PyTorchEMD/backend.py | 21 + third_party/PyTorchEMD/cuda/emd.cpp | 29 + third_party/PyTorchEMD/cuda/emd_kernel.cu | 398 +++++ third_party/PyTorchEMD/emd.py | 52 + third_party/PyTorchEMD/emd_cuda.py | 9 + third_party/PyTorchEMD/emd_nograd.py | 45 + third_party/PyTorchEMD/emd_static.py | 49 + third_party/PyTorchEMD/setup.py | 27 + third_party/PyTorchEMD/test_emd_loss.py | 44 + third_party/pvcnn/LICENSE | 21 + third_party/pvcnn/README.md | 2 + third_party/pvcnn/functional/__init__.py | 7 + third_party/pvcnn/functional/backend.py | 29 + third_party/pvcnn/functional/ball_query.py | 20 + .../pvcnn/functional/devoxelization.py | 45 + third_party/pvcnn/functional/grouping.py | 33 + .../pvcnn/functional/interpolatation.py | 54 + third_party/pvcnn/functional/loss.py | 18 + third_party/pvcnn/functional/sampling.py | 100 ++ .../functional/src/ball_query/ball_query.cpp | 30 + .../functional/src/ball_query/ball_query.cu | 59 + .../functional/src/ball_query/ball_query.cuh | 8 + .../functional/src/ball_query/ball_query.hpp | 10 + third_party/pvcnn/functional/src/bindings.cpp | 37 + .../pvcnn/functional/src/cuda_utils.cuh | 39 + .../functional/src/grouping/grouping.cpp | 44 + .../pvcnn/functional/src/grouping/grouping.cu | 85 + .../functional/src/grouping/grouping.cuh | 9 + .../functional/src/grouping/grouping.hpp | 10 + .../src/interpolate/neighbor_interpolate.cpp | 65 + .../src/interpolate/neighbor_interpolate.cu | 181 ++ .../src/interpolate/neighbor_interpolate.cuh | 16 + .../src/interpolate/neighbor_interpolate.hpp | 16 + .../src/interpolate/trilinear_devox.cpp | 91 + .../src/interpolate/trilinear_devox.cu | 178 ++ .../src/interpolate/trilinear_devox.cuh | 13 + 
.../src/interpolate/trilinear_devox.hpp | 16 + .../functional/src/sampling/sampling.cpp | 58 + .../pvcnn/functional/src/sampling/sampling.cu | 174 ++ .../functional/src/sampling/sampling.cuh | 11 + .../functional/src/sampling/sampling.hpp | 12 + third_party/pvcnn/functional/src/utils.hpp | 20 + .../pvcnn/functional/src/voxelization/vox.cpp | 76 + .../pvcnn/functional/src/voxelization/vox.cu | 126 ++ .../pvcnn/functional/src/voxelization/vox.cuh | 10 + .../pvcnn/functional/src/voxelization/vox.hpp | 15 + third_party/pvcnn/functional/voxelization.py | 47 + third_party/torchdiffeq/LICENSE | 21 + third_party/torchdiffeq/README.md | 1 + .../torchdiffeq/torchdiffeq/__init__.py | 4 + .../torchdiffeq/torchdiffeq/_impl/__init__.py | 2 + .../torchdiffeq/_impl/adaptive_heun.py | 25 + .../torchdiffeq/torchdiffeq/_impl/adjoint.py | 280 +++ .../torchdiffeq/torchdiffeq/_impl/bosh3.py | 22 + .../torchdiffeq/torchdiffeq/_impl/dopri5.py | 36 + .../torchdiffeq/torchdiffeq/_impl/dopri8.py | 76 + .../torchdiffeq/_impl/event_handling.py | 35 + .../torchdiffeq/_impl/fehlberg2.py | 22 + .../torchdiffeq/_impl/fixed_adams.py | 228 +++ .../torchdiffeq/_impl/fixed_grid.py | 29 + .../torchdiffeq/torchdiffeq/_impl/interp.py | 48 + .../torchdiffeq/torchdiffeq/_impl/misc.py | 353 ++++ .../torchdiffeq/torchdiffeq/_impl/odeint.py | 164 ++ .../torchdiffeq/_impl/rk_common.py | 307 ++++ .../torchdiffeq/_impl/scipy_wrapper.py | 53 + .../torchdiffeq/torchdiffeq/_impl/solvers.py | 172 ++ third_party/yacs_config.py | 586 +++++++ train_dist.py | 251 +++ trainers/base_trainer.py | 852 +++++++++ trainers/common_fun.py | 105 ++ trainers/common_fun_prior_train.py | 363 ++++ trainers/hvae_trainer.py | 204 +++ trainers/train_2prior.py | 449 +++++ trainers/train_prior.py | 741 ++++++++ utils/checker.py | 80 + utils/data_helper.py | 35 + utils/diffusion.py | 170 ++ utils/diffusion_continuous.py | 845 +++++++++ utils/diffusion_pvd.py | 563 ++++++ utils/ema.py | 120 ++ utils/eval_helper.py | 341 ++++ utils/evaluation_metrics_fast.py | 687 ++++++++ utils/exp_helper.py | 122 ++ utils/io_helper.py | 17 + utils/model_helper.py | 138 ++ utils/sr_utils.py | 115 ++ utils/utils.py | 1532 +++++++++++++++++ utils/vis_helper.py | 149 ++ 134 files changed, 18308 insertions(+), 10 deletions(-) create mode 100644 .gitignore create mode 100644 build_pkg.py create mode 100644 datasets/data_path.py create mode 100644 datasets/pointflow_datasets.py create mode 100644 default_config.py create mode 100644 demo.py create mode 100644 env.yaml create mode 100644 models/adagn.py create mode 100644 models/dense.py create mode 100644 models/distributions.py create mode 100644 models/latent_points_ada.py create mode 100644 models/latent_points_ada_localprior.py create mode 100644 models/lion.py create mode 100644 models/pvcnn2.py create mode 100644 models/pvcnn2_ada.py create mode 100644 models/score_sde/resnet.py create mode 100644 models/shapelatent_modules.py create mode 100644 models/utils.py create mode 100644 models/vae_adain.py create mode 100644 script/compute_score.py create mode 100644 third_party/ChamferDistancePytorch/.gitignore create mode 100644 third_party/ChamferDistancePytorch/LICENSE create mode 100755 third_party/ChamferDistancePytorch/README.md create mode 100755 third_party/ChamferDistancePytorch/chamfer2D/chamfer2D.cu create mode 100755 third_party/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp create mode 100644 third_party/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py create mode 100755 
third_party/ChamferDistancePytorch/chamfer2D/setup.py create mode 100755 third_party/ChamferDistancePytorch/chamfer3D/chamfer3D.cu create mode 100755 third_party/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp create mode 100644 third_party/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py create mode 100755 third_party/ChamferDistancePytorch/chamfer3D/setup.py create mode 100755 third_party/ChamferDistancePytorch/chamfer5D/chamfer5D.cu create mode 100755 third_party/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp create mode 100644 third_party/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py create mode 100755 third_party/ChamferDistancePytorch/chamfer5D/setup.py create mode 100755 third_party/ChamferDistancePytorch/chamfer6D/chamfer6D.cu create mode 100755 third_party/ChamferDistancePytorch/chamfer6D/chamfer_cuda.cpp create mode 100755 third_party/ChamferDistancePytorch/chamfer6D/dist_chamfer_6D.py create mode 100755 third_party/ChamferDistancePytorch/chamfer6D/setup.py create mode 100644 third_party/ChamferDistancePytorch/chamfer_python.py create mode 100644 third_party/ChamferDistancePytorch/fscore.py create mode 100644 third_party/ChamferDistancePytorch/unit_test.py create mode 100644 third_party/PyTorchEMD/.gitignore create mode 100644 third_party/PyTorchEMD/README.md create mode 100755 third_party/PyTorchEMD/__init__.py create mode 100755 third_party/PyTorchEMD/backend.py create mode 100755 third_party/PyTorchEMD/cuda/emd.cpp create mode 100644 third_party/PyTorchEMD/cuda/emd_kernel.cu create mode 100755 third_party/PyTorchEMD/emd.py create mode 100644 third_party/PyTorchEMD/emd_cuda.py create mode 100644 third_party/PyTorchEMD/emd_nograd.py create mode 100755 third_party/PyTorchEMD/emd_static.py create mode 100755 third_party/PyTorchEMD/setup.py create mode 100644 third_party/PyTorchEMD/test_emd_loss.py create mode 100644 third_party/pvcnn/LICENSE create mode 100644 third_party/pvcnn/README.md create mode 100644 third_party/pvcnn/functional/__init__.py create mode 100644 third_party/pvcnn/functional/backend.py create mode 100644 third_party/pvcnn/functional/ball_query.py create mode 100644 third_party/pvcnn/functional/devoxelization.py create mode 100644 third_party/pvcnn/functional/grouping.py create mode 100644 third_party/pvcnn/functional/interpolatation.py create mode 100644 third_party/pvcnn/functional/loss.py create mode 100644 third_party/pvcnn/functional/sampling.py create mode 100644 third_party/pvcnn/functional/src/ball_query/ball_query.cpp create mode 100644 third_party/pvcnn/functional/src/ball_query/ball_query.cu create mode 100644 third_party/pvcnn/functional/src/ball_query/ball_query.cuh create mode 100644 third_party/pvcnn/functional/src/ball_query/ball_query.hpp create mode 100644 third_party/pvcnn/functional/src/bindings.cpp create mode 100644 third_party/pvcnn/functional/src/cuda_utils.cuh create mode 100644 third_party/pvcnn/functional/src/grouping/grouping.cpp create mode 100644 third_party/pvcnn/functional/src/grouping/grouping.cu create mode 100644 third_party/pvcnn/functional/src/grouping/grouping.cuh create mode 100644 third_party/pvcnn/functional/src/grouping/grouping.hpp create mode 100644 third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cpp create mode 100644 third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cu create mode 100644 third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cuh create mode 100644 third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.hpp create mode 100644 
third_party/pvcnn/functional/src/interpolate/trilinear_devox.cpp create mode 100644 third_party/pvcnn/functional/src/interpolate/trilinear_devox.cu create mode 100644 third_party/pvcnn/functional/src/interpolate/trilinear_devox.cuh create mode 100644 third_party/pvcnn/functional/src/interpolate/trilinear_devox.hpp create mode 100644 third_party/pvcnn/functional/src/sampling/sampling.cpp create mode 100644 third_party/pvcnn/functional/src/sampling/sampling.cu create mode 100644 third_party/pvcnn/functional/src/sampling/sampling.cuh create mode 100644 third_party/pvcnn/functional/src/sampling/sampling.hpp create mode 100644 third_party/pvcnn/functional/src/utils.hpp create mode 100644 third_party/pvcnn/functional/src/voxelization/vox.cpp create mode 100644 third_party/pvcnn/functional/src/voxelization/vox.cu create mode 100644 third_party/pvcnn/functional/src/voxelization/vox.cuh create mode 100644 third_party/pvcnn/functional/src/voxelization/vox.hpp create mode 100644 third_party/pvcnn/functional/voxelization.py create mode 100644 third_party/torchdiffeq/LICENSE create mode 100644 third_party/torchdiffeq/README.md create mode 100644 third_party/torchdiffeq/torchdiffeq/__init__.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/__init__.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/adaptive_heun.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/adjoint.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/bosh3.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/dopri5.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/dopri8.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/event_handling.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/fehlberg2.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/fixed_adams.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/fixed_grid.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/interp.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/misc.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/odeint.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/rk_common.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/scipy_wrapper.py create mode 100644 third_party/torchdiffeq/torchdiffeq/_impl/solvers.py create mode 100644 third_party/yacs_config.py create mode 100644 train_dist.py create mode 100644 trainers/base_trainer.py create mode 100644 trainers/common_fun.py create mode 100644 trainers/common_fun_prior_train.py create mode 100644 trainers/hvae_trainer.py create mode 100644 trainers/train_2prior.py create mode 100644 trainers/train_prior.py create mode 100644 utils/checker.py create mode 100644 utils/data_helper.py create mode 100644 utils/diffusion.py create mode 100644 utils/diffusion_continuous.py create mode 100644 utils/diffusion_pvd.py create mode 100644 utils/ema.py create mode 100644 utils/eval_helper.py create mode 100644 utils/evaluation_metrics_fast.py create mode 100644 utils/exp_helper.py create mode 100644 utils/io_helper.py create mode 100644 utils/model_helper.py create mode 100644 utils/sr_utils.py create mode 100644 utils/utils.py create mode 100644 utils/vis_helper.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..34fedbd --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +__pycache__/ +*__pycache__ +.idea/ +*.pyc +*.m +.ipynb_checkpoints +*swp +*swo +*__pycache__* +models/pvcnn/functional/build/ +*.sh +lion_ckpt 
+data/ +datasets/test_data diff --git a/README.md b/README.md index c21fa1c..b08e0ba 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,76 @@ ##

LION: Latent Point Diffusion Models for 3D Shape Generation

NeurIPS 2022

- Xiaohui Zeng·   - Arash Vahdat·   - Francis Williams·   - Zan Gojcic·   - Or Litany·   - Sanja Fidler·   + Xiaohui Zeng   + Arash Vahdat   + Francis Williams   + Zan Gojcic   + Or Litany   + Sanja Fidler   + Karsten Kreis

Paper · Project Page
-

-

:construction: :pick: :hammer_and_wrench: :construction_worker:

-

Here, we will release code and checkpoints in the near future! Stay tuned!

-

+

Animation

+## Install +* Dependencies: + * CUDA 11.6 + +* Set up the environment + Install from the conda file + ``` + conda env create --name lion_env --file=env.yaml + conda activate lion_env + + # Install some other packages + pip install git+https://github.com/openai/CLIP.git + + # build some packages first (optional) + python build_pkg.py + ``` + Tested with conda version 22.9.0 + +## Demo +Run `python demo.py`; it will load the released text2shape model from Hugging Face and generate a chair point cloud. + +## Released checkpoint and samples +* will be released soon +* put the downloaded file under `./lion_ckpt/` + +## Training + +### data +* ShapeNet can be downloaded [here](https://github.com/stevenygd/PointFlow#dataset). +* Put the downloaded data as `./data/ShapeNetCore.v2.PC15k` *or* edit the `pointflow` entry in `./datasets/data_path.py` to point to the ShapeNet dataset path. + +### train VAE +* run `bash ./script/train_vae.sh $NGPU` (the released checkpoint is trained with `NGPU=4`) + +### train diffusion prior +* requires the VAE checkpoint +* run `bash ./script/train_prior.sh $NGPU` (the released checkpoint is trained with `NGPU=8` on 2 nodes) + +### evaluate a trained prior +* download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/` +* download the released checkpoint from above +``` +checkpoint="./lion_ckpt/unconditional/airplane/checkpoints/model.pt" +bash ./script/eval.sh $checkpoint # will take 1-2 hours +``` + +## Evaluate the samples with the 1-NNA metrics +* download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/` +* run `python ./script/compute_score.py` + +## Citation +``` +@inproceedings{zeng2022lion, + title={LION: Latent Point Diffusion Models for 3D Shape Generation}, + author={Xiaohui Zeng and Arash Vahdat and Francis Williams and Zan Gojcic and Or Litany and Sanja Fidler and Karsten Kreis}, + booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, + year={2022} +} +``` diff --git a/build_pkg.py b/build_pkg.py new file mode 100644 index 0000000..09abefc --- /dev/null +++ b/build_pkg.py @@ -0,0 +1,3 @@ +import clip +from models import pvcnn2 +from utils import eval_helper diff --git a/datasets/data_path.py b/datasets/data_path.py new file mode 100644 index 0000000..7c4f94e --- /dev/null +++ b/datasets/data_path.py @@ -0,0 +1,38 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
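For orientation, the `get_path` helper defined just below is the single place where dataset roots are registered; a hypothetical usage sketch (not part of the patch; the path is the default registered in this file):

```
# Hypothetical usage of datasets/data_path.py:
from datasets.data_path import get_path

root = get_path('pointflow')  # returns the first entry that exists on disk,
                              # e.g. './data/ShapeNetCore.v2.PC15k/'
```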
+import os + + +def get_path(dataname=None): + dataset_path = {} + dataset_path['pointflow'] = [ + './data/ShapeNetCore.v2.PC15k/' + + ] + + if dataname is None: + return dataset_path + else: + assert( + dataname in dataset_path), f'not found {dataname}, only: {list(dataset_path.keys())}' + for p in dataset_path[dataname]: + print(f'searching: {dataname}, get: {p}') + if os.path.exists(p): + return p + raise ValueError( + f'all path not found for {dataname}, please double check: {dataset_path[dataname]}; or edit the datasets/data_path.py ') + + +def get_cache_path(): + cache_list = ['/workspace/data_cache_local/data_stat/', + '/workspace/data_cache/data_stat/'] + for p in cache_list: + if os.path.exists(p): + return p + raise ValueError( + f'all path not found for {cache_list}, please double check: or edit the datasets/data_path.py ') diff --git a/datasets/pointflow_datasets.py b/datasets/pointflow_datasets.py new file mode 100644 index 0000000..9ab7c5e --- /dev/null +++ b/datasets/pointflow_datasets.py @@ -0,0 +1,404 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """ +import os +import open3d as o3d +import time +import torch +import numpy as np +from loguru import logger +from torch.utils.data import Dataset +from torch.utils import data +import random +import tqdm +from datasets.data_path import get_path +OVERFIT = 0 + +# taken from https://github.com/optas/latent_3d_points/blob/ +# 8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py +synsetid_to_cate = { + '02691156': 'airplane', + '02773838': 'bag', + '02801938': 'basket', + '02808440': 'bathtub', + '02818832': 'bed', + '02828884': 'bench', + '02876657': 'bottle', + '02880940': 'bowl', + '02924116': 'bus', + '02933112': 'cabinet', + '02747177': 'can', + '02942699': 'camera', + '02954340': 'cap', + '02958343': 'car', + '03001627': 'chair', + '03046257': 'clock', + '03207941': 'dishwasher', + '03211117': 'monitor', + '04379243': 'table', + '04401088': 'telephone', + '02946921': 'tin_can', + '04460130': 'tower', + '04468005': 'train', + '03085013': 'keyboard', + '03261776': 'earphone', + '03325088': 'faucet', + '03337140': 'file', + '03467517': 'guitar', + '03513137': 'helmet', + '03593526': 'jar', + '03624134': 'knife', + '03636649': 'lamp', + '03642806': 'laptop', + '03691459': 'speaker', + '03710193': 'mailbox', + '03759954': 'microphone', + '03761084': 'microwave', + '03790512': 'motorcycle', + '03797390': 'mug', + '03928116': 'piano', + '03938244': 'pillow', + '03948459': 'pistol', + '03991062': 'pot', + '04004475': 'printer', + '04074963': 'remote_control', + '04090263': 'rifle', + '04099429': 'rocket', + '04225987': 'skateboard', + '04256520': 'sofa', + '04330267': 'stove', + '04530566': 'vessel', + '04554684': 'washer', + '02992529': 'cellphone', + '02843684': 'birdhouse', + '02871439': 'bookshelf', + # '02858304': 'boat', no boat in our dataset, merged into vessels + # '02834778': 'bicycle', not in our taxonomy +} +cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()} + + +class ShapeNet15kPointClouds(Dataset): + def
__init__(self, + categories=['airplane'], + tr_sample_size=10000, + te_sample_size=10000, + split='train', + scale=1., + normalize_per_shape=False, + normalize_shape_box=False, + random_subsample=False, + sample_with_replacement=1, + normalize_std_per_axis=False, + normalize_global=False, + recenter_per_shape=False, + all_points_mean=None, + all_points_std=None, + input_dim=3, + ): + self.normalize_shape_box = normalize_shape_box + root_dir = get_path('pointflow') + self.root_dir = root_dir + logger.info('[DATA] cat: {}, split: {}, full path: {}; norm global={}, norm-box={}', + categories, split, self.root_dir, normalize_global, normalize_shape_box) + + self.split = split + assert self.split in ['train', 'test', 'val'] + self.tr_sample_size = tr_sample_size + self.te_sample_size = te_sample_size + if type(categories) is str: + categories = [categories] + self.cates = categories + + if 'all' in categories: + self.synset_ids = list(cate_to_synsetid.values()) + else: + self.synset_ids = [cate_to_synsetid[c] for c in self.cates] + subdirs = self.synset_ids + # assert 'v2' in root_dir, "Only supporting v2 right now." + self.gravity_axis = 1 + self.display_axis_order = [0, 2, 1] + + self.root_dir = root_dir + self.split = split + self.in_tr_sample_size = tr_sample_size + self.in_te_sample_size = te_sample_size + self.subdirs = subdirs + self.scale = scale + self.random_subsample = random_subsample + self.sample_with_replacement = sample_with_replacement + self.input_dim = input_dim + + self.all_cate_mids = [] + self.cate_idx_lst = [] + self.all_points = [] + tic = time.time() + for cate_idx, subd in enumerate(self.subdirs): + # NOTE: [subd] here is synset id + sub_path = os.path.join(root_dir, subd, self.split) + if not os.path.isdir(sub_path): + print("Directory missing : %s " % (sub_path)) + raise ValueError('check the data path') + continue + if True: + all_mids = [] + assert(os.path.exists(sub_path)), f'path missing: {sub_path}' + for x in os.listdir(sub_path): + if not x.endswith('.npy'): + continue + all_mids.append(os.path.join(self.split, x[:-len('.npy')])) + + logger.info('[DATA] number of file [{}] under: {} ', + len(os.listdir(sub_path)), sub_path) + # NOTE: [mid] contains the split: i.e. 
"train/" + # or "val/" or "test/" + all_mids = sorted(all_mids) + for mid in all_mids: + # obj_fname = os.path.join(sub_path, x) + obj_fname = os.path.join(root_dir, subd, mid + ".npy") + point_cloud = np.load(obj_fname) # (15k, 3) + self.all_points.append(point_cloud[np.newaxis, ...]) + self.cate_idx_lst.append(cate_idx) + self.all_cate_mids.append((subd, mid)) + + logger.info('[DATA] Load data time: {:.1f}s | dir: {} | ' + 'sample_with_replacement: {}; num points: {}', time.time() - tic, self.subdirs, + self.sample_with_replacement, len(self.all_points)) + + # Shuffle the index deterministically (based on the number of examples) + self.shuffle_idx = list(range(len(self.all_points))) + random.Random(38383).shuffle(self.shuffle_idx) + self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx] + self.all_points = [self.all_points[i] for i in self.shuffle_idx] + self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx] + + # Normalization + self.all_points = np.concatenate(self.all_points) # (N, 15000, 3) + self.normalize_per_shape = normalize_per_shape + self.normalize_std_per_axis = normalize_std_per_axis + self.recenter_per_shape = recenter_per_shape + if self.normalize_shape_box: # per shape normalization + B, N = self.all_points.shape[:2] + self.all_points_mean = ( # B,1,3 + (np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) + + (np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)) / 2 + self.all_points_std = np.amax( # B,1,1 + ((np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) - + (np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)), + axis=-1).reshape(B, 1, 1) / 2 + elif self.normalize_per_shape: # per shape normalization + B, N = self.all_points.shape[:2] + self.all_points_mean = self.all_points.mean(axis=1).reshape( + B, 1, input_dim) + logger.info('all_points shape: {}. 
mean over axis=1', + self.all_points.shape) + if normalize_std_per_axis: + self.all_points_std = self.all_points.reshape( + B, N, -1).std(axis=1).reshape(B, 1, input_dim) + else: + self.all_points_std = self.all_points.reshape( + B, -1).std(axis=1).reshape(B, 1, 1) + elif all_points_mean is not None and all_points_std is not None and not self.recenter_per_shape: + # using loaded dataset stats + self.all_points_mean = all_points_mean + self.all_points_std = all_points_std + elif self.recenter_per_shape: # per shape center + # TODO: bounding box scale at the large dim and center + B, N = self.all_points.shape[:2] + self.all_points_mean = ( + (np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) + + (np.amin(self.all_points, axis=1)).reshape(B, 1, + input_dim)) / 2 + self.all_points_std = np.amax( + ((np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) - + (np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)), + axis=-1).reshape(B, 1, 1) / 2 + # else: # normalize across the dataset + elif normalize_global: # normalize across the dataset + self.all_points_mean = self.all_points.reshape( + -1, input_dim).mean(axis=0).reshape(1, 1, input_dim) + + if normalize_std_per_axis: + self.all_points_std = self.all_points.reshape( + -1, input_dim).std(axis=0).reshape(1, 1, input_dim) + else: + self.all_points_std = self.all_points.reshape(-1).std( + axis=0).reshape(1, 1, 1) + + logger.info('[DATA] normalize_global: mean={}, std={}', + self.all_points_mean.reshape(-1), + self.all_points_std.reshape(-1)) + else: + raise NotImplementedError('No Normalization') + self.all_points = (self.all_points - self.all_points_mean) / \ + self.all_points_std + logger.info('[DATA] shape={}, all_points_mean:={}, std={}, max={:.3f}, min={:.3f}; num-pts={}', + self.all_points.shape, + self.all_points_mean.shape, self.all_points_std.shape, + self.all_points.max(), self.all_points.min(), tr_sample_size) + + if OVERFIT: + self.all_points = self.all_points[:40] + + # TODO: why do we need this?? 
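For a single `(N, 3)` cloud, the per-shape box normalization above (the `normalize_shape_box` / `recenter_per_shape` branches) reduces to the following standalone numpy sketch (illustrative only, not part of the patch):

```
import numpy as np

pts = np.random.rand(15000, 3).astype(np.float32)          # one shape, (N, 3)
center = (pts.max(axis=0) + pts.min(axis=0)) / 2.0         # bounding-box midpoint, (3,)
scale = ((pts.max(axis=0) - pts.min(axis=0)) / 2.0).max()  # half the largest box extent
normed = (pts - center) / scale                            # largest axis now spans [-1, 1]
```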
+ self.train_points = self.all_points[:, :min( + 10000, self.all_points.shape[1])] + self.tr_sample_size = min(10000, tr_sample_size) + self.te_sample_size = min(5000, te_sample_size) + assert self.scale == 1, "Scale (!= 1) is deprecated" + + # Default display axis order + self.display_axis_order = [0, 1, 2] + + def get_pc_stats(self, idx): + if self.recenter_per_shape: + m = self.all_points_mean[idx].reshape(1, self.input_dim) + s = self.all_points_std[idx].reshape(1, -1) + return m, s + + if self.normalize_per_shape or self.normalize_shape_box: + m = self.all_points_mean[idx].reshape(1, self.input_dim) + s = self.all_points_std[idx].reshape(1, -1) + return m, s + + return self.all_points_mean.reshape(1, -1), \ + self.all_points_std.reshape(1, -1) + + def renormalize(self, mean, std): + self.all_points = self.all_points * self.all_points_std + \ + self.all_points_mean + self.all_points_mean = mean + self.all_points_std = std + self.all_points = (self.all_points - self.all_points_mean) / \ + self.all_points_std + self.train_points = self.all_points[:, :min( + 10000, self.all_points.shape[1])] + ## self.test_points = self.all_points[:, 10000:] + + def __len__(self): + return len(self.train_points) + + def __getitem__(self, idx): + output = {} + tr_out = self.train_points[idx] + if self.random_subsample and self.sample_with_replacement: + tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size) + elif self.random_subsample and not self.sample_with_replacement: + tr_idxs = np.random.permutation( + np.arange(tr_out.shape[0]))[:self.tr_sample_size] + else: + tr_idxs = np.arange(self.tr_sample_size) + tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float() + m, s = self.get_pc_stats(idx) + + cate_idx = self.cate_idx_lst[idx] + sid, mid = self.all_cate_mids[idx] + input_pts = tr_out + + output.update( + { + 'idx': idx, + 'select_idx': tr_idxs, + 'tr_points': tr_out, + 'input_pts': input_pts, + 'mean': m, + 'std': s, + 'cate_idx': cate_idx, + 'sid': sid, + 'mid': mid, + 'display_axis_order': self.display_axis_order + }) + return output + + +def init_np_seed(worker_id): + seed = torch.initial_seed() + np.random.seed(seed % 4294967296) + + +def get_datasets(cfg, args): + """ + cfg: config.data sub part + """ + if OVERFIT: + random_subsample = 0 + else: + random_subsample = cfg.random_subsample + logger.info(f'get_datasets: tr_sample_size={cfg.tr_max_sample_points}, ' + f' te_sample_size={cfg.te_max_sample_points}; ' + f' random_subsample={random_subsample}' + f' normalize_global={cfg.normalize_global}' + f' normalize_std_per_axix={cfg.normalize_std_per_axis}' + f' normalize_per_shape={cfg.normalize_per_shape}' + f' recenter_per_shape={cfg.recenter_per_shape}' + ) + kwargs = {} + tr_dataset = ShapeNet15kPointClouds( + categories=cfg.cates, + split='train', + tr_sample_size=cfg.tr_max_sample_points, + te_sample_size=cfg.te_max_sample_points, + sample_with_replacement=cfg.sample_with_replacement, + scale=cfg.dataset_scale, # root_dir=cfg.data_dir, + normalize_shape_box=cfg.normalize_shape_box, + normalize_per_shape=cfg.normalize_per_shape, + normalize_std_per_axis=cfg.normalize_std_per_axis, + normalize_global=cfg.normalize_global, + recenter_per_shape=cfg.recenter_per_shape, + random_subsample=random_subsample, + **kwargs) + + eval_split = getattr(args, "eval_split", "val") + # te_dataset has random_subsample as False, therefore not using sample_with_replacement + te_dataset = ShapeNet15kPointClouds( + categories=cfg.cates, + split=eval_split, + tr_sample_size=cfg.tr_max_sample_points, + 
te_sample_size=cfg.te_max_sample_points, + scale=cfg.dataset_scale, # root_dir=cfg.data_dir, + normalize_shape_box=cfg.normalize_shape_box, + normalize_per_shape=cfg.normalize_per_shape, + normalize_std_per_axis=cfg.normalize_std_per_axis, + normalize_global=cfg.normalize_global, + recenter_per_shape=cfg.recenter_per_shape, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return tr_dataset, te_dataset + + +def get_data_loaders(cfg, args): + tr_dataset, te_dataset = get_datasets(cfg, args) + kwargs = {} + if args.distributed: + kwargs['sampler'] = data.distributed.DistributedSampler( + tr_dataset, shuffle=True) + else: + kwargs['shuffle'] = True + if args.eval_trainnll: + kwargs['shuffle'] = False + train_loader = data.DataLoader(dataset=tr_dataset, + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + drop_last=cfg.train_drop_last == 1, + pin_memory=False, **kwargs) + test_loader = data.DataLoader(dataset=te_dataset, + batch_size=cfg.batch_size_test, + shuffle=False, + num_workers=cfg.num_workers, + pin_memory=False, + drop_last=False, + ) + logger.info( + f'[Batch Size] train={cfg.batch_size}, test={cfg.batch_size_test}; drop-last={cfg.train_drop_last}') + loaders = { + "test_loader": test_loader, + 'train_loader': train_loader, + } + return loaders diff --git a/default_config.py b/default_config.py new file mode 100644 index 0000000..731dc86 --- /dev/null +++ b/default_config.py @@ -0,0 +1,450 @@ +# --------------------------------------------------------------- +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
+# --------------------------------------------------------------- + + +from third_party.yacs_config import CfgNode as CN + +cfg = CN() +cfg.dpm_ckpt = '' +cfg.clipforge = CN() +cfg.clipforge.clip_model = "ViT-B/32" +cfg.clipforge.enable = 0 +cfg.clipforge.feat_dim = 512 +cfg.eval_trainnll = 0 +cfg.exp_name = '' +cfg.cmt = '' +cfg.hash = '' +cfg.ngpu = 1 +cfg.snapshot_min = 30 # snapshot every 30 min +cfg.bash_name = '' +cfg.set_detect_anomaly = 0 +cfg.weight_recont = 1.0 +# vae ckpt +# lns +cfg.use_checkpoint = 0 +cfg.num_val_samples = 16 # 24 #12 + +# config for pointtransformer +cfg.eval = CN() +cfg.eval.need_denoise = 0 +cfg.eval.load_other_vae_ckpt = 0 +cfg.register_deprecated_key('eval.other_vae_ckpt_path') +cfg.vis_latent_point = 0 +cfg.latent_pts = CN() +#cfg.latent_pts.class_embed_layer = '' +cfg.register_deprecated_key('latent_pts.class_embed_layer') +cfg.latent_pts.style_dim = 128 # dim of global style latent variable +cfg.register_deprecated_key('latent_pts.perturb_input') +cfg.register_deprecated_key('latent_pts.perturb_input_scale') +cfg.register_deprecated_key('latent_pts.outlier_input') + +# scale of init weights for the mlp in adaGN layer +cfg.latent_pts.ada_mlp_init_scale = 1.0 +# models.latent_points_ada.StyleMLP' # style mlp layers +cfg.latent_pts.style_mlp = '' +cfg.latent_pts.pts_sigma_offset = 0.0 +cfg.latent_pts.skip_weight = 0.1 +cfg.latent_pts.encoder_layer_out_dim = 32 +cfg.latent_pts.decoder_layer_out_dim = 32 +cfg.register_deprecated_key('latent_pts.encoder_nneighbor') +cfg.register_deprecated_key('latent_pts.decoder_nneighbor') +cfg.latent_pts.style_prior = 'models.score_sde.resnet.PriorSEDrop' +cfg.latent_pts.mask_out_extra_latent = 0 # use only latent coordinates +# latent coordinates directly same as input (not using the decoder and encoder) +cfg.register_deprecated_key('latent_pts.latent_as_pts') + +cfg.latent_pts.normalization = 'bn' # BatchNorm or LayerNorm +cfg.latent_pts.pvd_mse_loss = 0 +cfg.latent_pts.hid = 64 + +cfg.register_deprecated_key('latent_pts.knn') +cfg.register_deprecated_key('latent_pts.n5layer') +cfg.register_deprecated_key('latent_pts.dgcnn_last_hid') + +cfg.latent_pts.latent_dim_ext = [64] # the global latent dim +cfg.latent_pts.weight_kl_pt = 1.0 # kl ratio of the pts +cfg.latent_pts.weight_kl_feat = 1.0 # kl ratio of the latent feat +cfg.latent_pts.weight_kl_glb = 1.0 # kl ratio of the latent feat +# kl ratio of the latent feat +cfg.latent_pts.style_encoder = 'models.shapelatent_modules.PointNetPlusEncoder' +cfg.latent_pts.use_linear_for_adagn = 0 +# cfg.latent_pts.weight_kl_glb = 1.0 # kl ratio of the global latent + +# shapelatent: +cfg.has_shapelatent = 1 +cfg.shapelatent = CN() +cfg.shapelatent.local_emb_agg = 'mean' +cfg.shapelatent.freeze_vae = 0 # learn vae +cfg.shapelatent.eps_z_global_only = 1 +cfg.shapelatent.model = 'flow' +cfg.shapelatent.residual = 1 +cfg.shapelatent.encoder_type = 'pointnet' +cfg.shapelatent.prior_type = 'flow' +cfg.shapelatent.decoder_type = 'PointwiseNet' +cfg.shapelatent.loss0_weight = 1.0 +cfg.shapelatent.latent_dim = 256 +cfg.shapelatent.kl_weight = 1e-3 +cfg.shapelatent.decoder_num_points = -1 +# offset the sigma towards zero for better init, will use the log_sigma - offset value, better to be positive s.t. 
- offset < 0 since we'd like to push it towards 0; exp(-0.1)=0.9, exp(-0.8)=0.44, exp(-1)=0.3, exp(-10)=4e-5 +cfg.shapelatent.log_sigma_offset = 0.0 + +cfg.sde = CN() +cfg.sde.ode_sample = 0 #1 +# train the prior or not, default is 1, only when we do voxel2pts, will freeze prior +cfg.sde.train_dae = 1 +cfg.sde.init_t = 1.0 # start from time = 1.0 +cfg.sde.nhead = 4 # number of head in transformder: multi-head attention layer +cfg.sde.local_prior = 'same_as_global' # architecture for local prior +cfg.sde.drop_inactive_var = 0 +cfg.sde.learn_mixing_logit = 1 # freeze it +cfg.sde.regularize_mlogit_margin = 0.0 +cfg.sde.share_mlogit = 0 # use same mlogit for all latent variables +cfg.sde.hypara_mixing_logit = 0 # set as hyper-parameter and freeze it? +cfg.sde.bound_mlogit = 0 # clamp or not +cfg.sde.bound_mlogit_value = -5.42 # clamp the max value +cfg.sde.regularize_mlogit = 0 # set the sum of sigmoid(mlogit) as one loss +cfg.sde.attn_mhead = 0 # use multi-head attention in prior model +cfg.sde.attn_mhead_local = -1 # use multi-head attention in prior model +cfg.sde.pos_embed = 'none' +cfg.sde.hier_prior = 0 +cfg.sde.is_continues = 0 +cfg.sde.time_emb_scales = 1.0 # -> 1k? +cfg.sde.time_eps = 1e-2 +cfg.sde.ode_eps = 1e-5 # cut off for ode sampling +cfg.sde.sde_type = 'vpsde' # vada +cfg.sde.sigma2_0 = 0.0 +cfg.sde.sigma2_max = 0.99 +cfg.sde.sigma2_min = 1e-4 +cfg.sde.beta_start = 0.1 # 1e-4 * 1e3 +cfg.sde.beta_end = 20.0 # 1e-2 * 1e3 +# sampling, always iw # ll: small times; 'll_uniform' # -> ll_iw +cfg.sde.iw_sample_p = 'll_iw' +# drop_all_iw / drop_sigma2t_iw +cfg.sde.iw_subvp_like_vp_sde = False +cfg.sde.prior_model = 'models.latent_points_ada_localprior.PVCNN2Prior' + +# -- to train diffusion in latent space -- # +cfg.sde.update_q_ema = False +cfg.sde.iw_sample_q = 'reweight_p_samples' +# ll_iw / reweight_p_samples +cfg.sde.kl_anneal_portion_vada = 0.1 +cfg.sde.kl_const_portion_vada = 0.0 +cfg.sde.kl_const_coeff_vada = 0.7 +cfg.sde.kl_balance_vada = False +cfg.sde.grad_clip_max_norm = 0.0 +cfg.sde.cont_kl_anneal = True +# False +cfg.sde.mixing_logit_init = -6 +cfg.sde.weight_decay_norm_vae = 0.0 #1e-2 +cfg.sde.weight_decay_norm_dae = 0.0 #1e-2 +# -> 0, for sn calculator +cfg.sde.train_vae = True +cfg.sde.jac_reg_coeff = 0 +cfg.sde.jac_reg_freq = 1 +cfg.sde.kin_reg_coeff = 0 +cfg.sde.learning_rate_mlogit = -1.0 +cfg.sde.learning_rate_dae_local = 3e-4 +cfg.sde.learning_rate_min_dae_local = 3e-4 +cfg.sde.learning_rate_dae = 3e-4 +cfg.sde.learning_rate_min_dae = 3e-4 +cfg.sde.learning_rate_min_vae = 1e-5 +cfg.sde.learning_rate_vae = 1e-4 +cfg.sde.epochs = 800 +cfg.sde.warmup_epochs = 20 +cfg.sde.weight_decay = 3e-4 +cfg.sde.use_adamax = False +cfg.sde.use_adam = True # False +cfg.sde.mixed_prediction = False # True +cfg.sde.vae_checkpoint = '' +cfg.sde.dae_checkpoint = '' +# will be used to multiply with the t value, if ode solver, use 1k, if discrete solver, use 1.0 +cfg.sde.embedding_scale = 1.0 # 1000.0 +cfg.sde.embedding_type = 'positional' +cfg.sde.train_ode_solver_tol = 1e-5 +cfg.sde.num_scales_dae = 2 +cfg.sde.autocast_train = False +cfg.sde.diffusion_steps = 1000 +cfg.sde.embedding_dim = 128 +cfg.sde.num_channels_dae = 256 +cfg.sde.num_cell_per_scale_dae = 8 +cfg.sde.num_cell_per_scale_dae_local = 0 +cfg.sde.dropout = 0.2 +cfg.sde.num_preprocess_blocks = 2 +cfg.sde.num_latent_scales = 1 +cfg.sde.fir = False +cfg.sde.progressive = 'none' +cfg.sde.progressive_input = 'none' +cfg.sde.progressive_combine = 'sum' +cfg.sde.dataset = 'shape' +cfg.sde.denoising_stddevs = 'beta' 
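The `sde.beta_start=0.1` / `sde.beta_end=20.0` defaults above match the common linear VPSDE; a minimal sketch of the implied perturbation kernel, assuming the standard continuous-time formulation (the repo's `utils/diffusion_continuous.py` is the authoritative implementation and may differ in detail):

```
import numpy as np

# Linear VPSDE: beta(t) = b0 + t * (b1 - b0) for t in [0, 1], with integral B(t).
# Forward kernel: q(x_t | x_0) = N(exp(-B(t)/2) * x_0, (1 - exp(-B(t))) * I).
def B(t, b0=0.1, b1=20.0):
    return b0 * t + 0.5 * (b1 - b0) * t ** 2

def mean_coeff(t):
    return np.exp(-0.5 * B(t))

def var(t):
    return 1.0 - np.exp(-B(t))

print(mean_coeff(1.0), var(1.0))  # ~0.007, ~1.0: data is almost pure noise at t=1
```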
+cfg.sde.ema_decay = 0.9999 +# cfg.sde.is_train_vae=True +cfg.register_deprecated_key("sde.is_train_vae") +cfg.sde.kl_max_coeff_vada = 1.0 +# conditional prior input +cfg.sde.condition_add = 1 +cfg.sde.condition_cat = 0 +cfg.sde.global_prior_ckpt = '' # checkpoint for global prior component +cfg.sde.pool_feat_cat = 0 # the local prior aggregate the feat as extra input channels + +# hyperparameter of ddim sampling +cfg.sde.ddim_skip_type = 'uniform' +cfg.sde.ddim_kappa = 1.0 # 1.0: fully ddpm sampling; 0: ode style sampling + +cfg.ddpm = CN() +cfg.ddpm.use_p2_weight = 0 +cfg.ddpm.p2_k = 1.0 +cfg.ddpm.p2_gamma = 1.0 +cfg.ddpm.use_new_timeemb = 0 +cfg.ddpm.input_dim = 3 +cfg.ddpm.dropout = 0.1 +cfg.ddpm.num_layers_classifier = 3 +cfg.ddpm.use_bn = True +cfg.ddpm.add_point_feat = True +cfg.ddpm.use_gn = False +cfg.ddpm.time_dim = 64 +cfg.ddpm.ema = 1 +cfg.ddpm.with_se = 0 +cfg.ddpm.use_global_attn = 0 +cfg.ddpm.num_steps = 1000 +cfg.ddpm.beta_1 = 1e-4 +cfg.ddpm.beta_T = 2e-2 +# ['linear', 'customer'] 'customer' for airplane in PVD +cfg.ddpm.sched_mode = 'linear' +cfg.ddpm.model_var_type = 'fixedlarge' +# define architecture: +cfg.register_deprecated_key("ddpm.pointnet_plus") +cfg.register_deprecated_key("ddpm.pointnet_pp") +cfg.register_deprecated_key("ddpm.pointnet_luo") +# end define architecture +#cfg.ddpm.use_pvc = 1 +cfg.register_deprecated_key("ddpm.use_pvc") +cfg.ddpm.clip_denoised = 0 +cfg.ddpm.model_mean_type = 'eps' +cfg.ddpm.loss_type = 'mse' +cfg.ddpm.loss_type_0 = '' +cfg.ddpm.loss_weight_emd = 0.02 +cfg.ddpm.loss_weight_cdnorm = 1.0 +cfg.ddpm.attn = [0, 1, 0, 0] +cfg.ddpm.ncenter = [1024, 256, 64, 16] + +#cfg.ddpm.pvc = CN() +#cfg.ddpm.pvc.use_small_model = 0 +#cfg.ddpm.pvc.mlp_after_pvc = 0 +cfg.register_deprecated_key("ddpm.pvc") +cfg.register_deprecated_key("ddpm.pvc.use_small_model") +cfg.register_deprecated_key("ddpm.pvc.mlp_after_pvc") + +cfg.ddpm.ddim_step = 200 + +cfg.data = CN() +cfg.data.nclass = 55 +cfg.data.cond_on_cat = 0 +cfg.data.cond_on_voxel = 0 +cfg.data.eval_test_split = 0 # eval loader will be using test split +cfg.data.voxel_size = 0.1 # size of voxel for voxel_datasets.py +cfg.data.noise_std = 0.1 # std for the noise added to the input data +cfg.data.noise_type = 'normal' # std for the noise added to the input data +cfg.data.noise_std_min = -1.0 # for range of noise std +cfg.data.clip_forge_enable = 0 +cfg.data.clip_model = 'ViT-B/32' +cfg.data.type = "datasets.pointflow_datasets" +# datasets/neuralspline_datasets datasets/shape_curvature +cfg.data.dataset_type = "shapenet15k" +cfg.data.num_workers = 12 # 8 +cfg.data.train_drop_last = 1 # drop_last for train data loader +cfg.data.cates = 'chair' # data category +cfg.data.tr_max_sample_points = 2048 +cfg.data.te_max_sample_points = 2048 +cfg.data.data_dir = "data/ShapeNetCore.v2.PC15k" # depreciated +cfg.data.batch_size = 12 +cfg.data.batch_size_test = 10 +cfg.data.dataset_scale = 1 +# -- the following option in terms of normalization should turn into string -- # +cfg.data.normalize_per_shape = False +cfg.data.normalize_shape_box = False +cfg.data.normalize_global = False +cfg.data.normalize_std_per_axis = False +cfg.data.normalize_range = False # not used +cfg.data.recenter_per_shape = True +# -- for the normal prediction model, used in folder_datasets +cfg.register_deprecated_key('data.load_point_stat') +cfg.register_deprecated_key('data.is_load_pointflow2NS') +cfg.register_deprecated_key('data.data_path') + +# +cfg.data.sample_with_replacement = 1 +# fixed the data.tr_max_sample_points $np 
data.te_max_sample_points $np2048 points of the first 15k points +cfg.data.random_subsample = 1 +# the data dim, used in dataset worker, if -1, it will be the same as ddpm.input_dim +cfg.data.input_dim = -1 +cfg.data.is_encode_whole_dataset_trainer = 0 +cfg.register_deprecated_key('data.augment') +cfg.register_deprecated_key('data.aug_translate') +cfg.register_deprecated_key('data.aug_scale') +cfg.register_deprecated_key('data.sub_train_set') + +cfg.test_size = 660 + +cfg.viz = CN() +cfg.viz.log_freq = 10 +cfg.viz.viz_freq = 400 +cfg.viz.save_freq = 200 +cfg.viz.val_freq = -1 +cfg.viz.viz_order = [2, 0, 1] +cfg.viz.vis_sample_ddim_step = 0 + +cfg.trainer = CN() +# when loss 1 is weighted, also weight the kl terms +cfg.trainer.apply_loss_weight_1_kl = 0 +cfg.trainer.kl_free = [0, 0] # the value for the threshold +# not back ward kl loss if KL value is smaller than the threshold +cfg.trainer.use_kl_free = 0 +cfg.trainer.type = "trainers.ddpm_trainer" # it means dist trainer +cfg.trainer.epochs = 10000 +cfg.trainer.warmup_epochs = 0 +cfg.trainer.seed = 1 +cfg.trainer.use_grad_scalar = 0 +cfg.trainer.opt = CN() +cfg.trainer.opt.type = 'adam' +cfg.trainer.opt.lr = 1e-4 # use bs*1e-5/8 +cfg.trainer.opt.lr_min = 1e-4 # use bs*1e-5/8 +# lr start to anneal after ratio of epochs; used in cosine and lambda lr scheduler +cfg.trainer.opt.start_ratio = 0.6 +cfg.trainer.opt.beta1 = 0.9 +cfg.trainer.opt.beta2 = 0.999 +cfg.trainer.opt.momentum = 0.9 # for SGD +cfg.trainer.opt.weight_decay = 0. +cfg.trainer.opt.ema_decay = 0.9999 +cfg.trainer.opt.grad_clip = -1. +cfg.trainer.opt.scheduler = '' +cfg.trainer.opt.step_decay = 0.998 +cfg.trainer.opt.vae_lr_warmup_epochs = 0 +cfg.trainer.anneal_kl = 0 +cfg.trainer.kl_balance = 0 +cfg.trainer.rec_balance = 0 +cfg.trainer.loss1_weight_anneal_v = 'quad' +cfg.trainer.kl_ratio = [1.0, 1.0] +cfg.trainer.kl_ratio_apply = 0 # apply the fixed kl ratio in the kl_ratio list +# using spectral norm regularization on vae training or not (used in hvae_trainer) +cfg.trainer.sn_reg_vae = 0 +cfg.trainer.sn_reg_vae_weight = 0.0 # loss weight for the sn regulatrization + +# [start] set in runtime +cfg.log_name = '' +cfg.save_dir = '' +cfg.log_dir = '' +cfg.comet_key = '' +# [end] + +cfg.voxel2pts = CN() +cfg.voxel2pts.init_weight = '' +cfg.voxel2pts.diffusion_steps = [0] + +cfg.dpm = CN() +cfg.dpm.train_encoder_only = 0 +cfg.num_ref = 0 # manully set the number of reference +cfg.eval_ddim_step = 0 # ddim sampling for the model evaluation +cfg.model_config = '' # used for model control, without ading new flag + +## --- depreciated --- # +cfg.register_deprecated_key('cls') # CN() +cfg.register_deprecated_key('cls.classifier_type') # 'models.classifier.OneLayer' +cfg.register_deprecated_key('cls.train_on_eps') # 1 +cfg.register_deprecated_key('cond_prior') # CN() +cfg.register_deprecated_key('cond_prior.grid_emb_resolution') # 32 +cfg.register_deprecated_key('cond_prior.emb_dim') # 64 +cfg.register_deprecated_key('cond_prior.use_voxel_feat') # 1 +cfg.register_deprecated_key('cond_encoder_prior') # 'models.shapelatent_modules.VoxelGridEncoder' +cfg.register_deprecated_key('cond_prior.pvcconv_concat_3d_feat_input') # 0 +cfg.register_deprecated_key('generate_mode_global') # 'interpolate' +cfg.register_deprecated_key('generate_mode_local') # 'freeze' +cfg.register_deprecated_key('normals') # CN() +cfg.register_deprecated_key('normals.model_type') # '' +cfg.register_deprecated_key('save_sample_seq_and_quit') # 0 +cfg.register_deprecated_key('lns_loss_weight') # 1.0 
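The `cfg.ddpm.*` defaults above (`beta_1=1e-4`, `beta_T=2e-2`, `num_steps=1000`, `sched_mode='linear'`) imply the standard discrete DDPM forward process; a minimal sketch under that usual parameterization (the repo's `utils/diffusion.py` is authoritative):

```
import numpy as np

betas = np.linspace(1e-4, 2e-2, 1000)     # linear schedule, ddpm.num_steps entries
alphas_cumprod = np.cumprod(1.0 - betas)  # \bar{alpha}_t
# Forward kernel: q(x_t | x_0) = N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I)
sqrt_a_bar = np.sqrt(alphas_cumprod)
sqrt_one_minus_a_bar = np.sqrt(1.0 - alphas_cumprod)
```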
+cfg.register_deprecated_key('normal_pred_checkpoint') # '' +cfg.register_deprecated_key('lns') # CN() +cfg.register_deprecated_key('lns.override_config') # '' +cfg.register_deprecated_key('lns.wandb_checkpoint') # 'nvidia-toronto/generative_chairs/3m3gc6sz/checkpoint-171.pth' +cfg.register_deprecated_key('lns.num_input_points') # 1000 +cfg.register_deprecated_key('lns.num_simulate') # 20 +cfg.register_deprecated_key('lns.split_simulate') # 'train' +# use mesh-trainer or not +cfg.register_deprecated_key('with_lns') # 0 + +cfg.register_deprecated_key('normal_predictor_yaml') # '' + +cfg.register_deprecated_key('pointtransformer') # CN() +# number of attention layer in each block +cfg.register_deprecated_key('pointtransformer.blocks') # [2, 3, 4, 6, 3] +cfg.register_deprecated_key('shapelatent.refiner_bp') # 1 # bp gradient to the local-decoder or not +cfg.register_deprecated_key('shapelatent.loss_weight_refiner') # 1.0 # weighted loss for the refiner +cfg.register_deprecated_key('shapelatent.refiner_type') # 'models.pvcnn2.PVCNN2BaseAPI' # mode for the refiner + +cfg.register_deprecated_key('shapelatent.encoder_weight_std') # 0.1 +cfg.register_deprecated_key('shapelatent.encoder_weight_norm') # 0 +cfg.register_deprecated_key('shapelatent.encoder_weight_uniform') # 1 +cfg.register_deprecated_key('shapelatent.key_point_gen') # 'mlps' +cfg.register_deprecated_key('shapelatent.add_sub_loss') # 1 # not used +cfg.register_deprecated_key('shapelatent.local_decoder_type') # '' +cfg.register_deprecated_key('shapelatent.local_decoder_type_1') # '' +cfg.register_deprecated_key('shapelatent.local_encoder_ball_radius') # 0.8 +cfg.register_deprecated_key('shapelatent.local_encoder_ap_ball_radius') # 1.0 +cfg.register_deprecated_key('shapelatent.local_encoder_type') # '' +cfg.register_deprecated_key('shapelatent.local_encoder_type_1') # '' +cfg.register_deprecated_key('shapelatent.local_loss_weight_max') # 50 +cfg.register_deprecated_key('shapelatent.num_neighbors') # 0 +cfg.register_deprecated_key('shapelatent.extra_centers') # [] +# for latent model is flow +cfg.register_deprecated_key('shapelatent.latent_flow_depth') # 14 +cfg.register_deprecated_key('shapelatent.latent_flow_hidden_dim') # 256 +cfg.register_deprecated_key('shapelatent.bp_to_l0') # True +cfg.register_deprecated_key('shapelatent.global_only_epochs') # 0 +cfg.register_deprecated_key('shapelatent.center_local_points') # 1 +cfg.register_deprecated_key('shapelatent.hvae') # CN() +# alternatively way to compute the local loss +cfg.register_deprecated_key('shapelatent.hvae.loss_wrt_ori') # 0 +# add voxel feature to the latent space; the decoder require pvc conv or query +cfg.register_deprecated_key('shapelatent.add_voxel2z_global') # 0 +# reuse the encoder to get local latent +cfg.register_deprecated_key('shapelatent.query_output_local_from_enc') # 0 +# check models/shapelatent_modules where the feature will be saved as a dict +cfg.register_deprecated_key('shapelatent.query_local_feat_layer') # 'inter_voxelfeat_0' +# need to check the sa_blocks of the global encoder +cfg.register_deprecated_key('shapelatent.query_local_feat_dim') # 32 +# reuse the encoder to get local latent +cfg.register_deprecated_key('shapelatent.query_center_emd_from_enc') # 0 # reuse the encoder for center emd +cfg.register_deprecated_key('shapelatent.prog_dec_gf') # 8 # grow_factor in VaniDecoderProg +cfg.register_deprecated_key('shapelatent.prog_dec_gf_list') # [0, 0] # grow_factor in VaniDecoderProg +cfg.register_deprecated_key('shapelatent.prog_dec_ne') # 2 # 
num_expand in VaniDecoderProg +# increase number hirach, used by hvaemul model +cfg.register_deprecated_key('shapelatent.num_neighbors_per_level') # [64] # number of neighbors for each level +cfg.register_deprecated_key('shapelatent.num_level') # 1 # number of hierarchi latent space (local) +cfg.register_deprecated_key('shapelatent.x0_target_fps') # 0 # let the target of global output as the +cfg.register_deprecated_key('shapelatent.downsample_input_ratio') # 1.0 +# whether taking other tensor as input to local-encoder of not +cfg.register_deprecated_key('shapelatent.local_enc_input') # 'sim' +# local encoder take z0 as input at which location +cfg.register_deprecated_key('shapelatent.local_encoder_condition_z0') # '' +# output the absolution coordinates or the offset w.r.t centers +cfg.register_deprecated_key('shapelatent.local_decoder_output_offset') # 0 +# feed coords of keypoints to the local prior model +cfg.register_deprecated_key('shapelatent.local_prior_need_coords') # 0 + +# add the time embedding tensor to each encoder layer instead of add to first layer only +cfg.register_deprecated_key('sde.transformer_temb2interlayer') # 0 +# normalization used in transformer encoder; +cfg.register_deprecated_key('sde.transformer_norm_type') # 'layer_norm' +cfg.register_deprecated_key('data.has_normal') # 0 # for datasets/pointflow_rgb.py only +cfg.register_deprecated_key('data.has_color') # 0 # for datasets/pointflow_rgb.py only +cfg.register_deprecated_key('data.cls_data_ratio') # 1.0 # ratio of the training data +cfg.register_deprecated_key('data.sample_curvature') # 0 # only for datasets/shape_curvature +cfg.register_deprecated_key('data.ratio_c') # 1.0 # only for datasets/shape_curvature diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..aeea0a5 --- /dev/null +++ b/demo.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
+ +""" + require diffusers-0.11.1 +""" +import os +import clip +import torch +from PIL import Image +from default_config import cfg as config +from models.lion import LION +from utils.vis_helper import plot_points +from huggingface_hub import hf_hub_download + +model_path = './lion_ckpt/text2shape/chair/checkpoints/model.pt' +model_config = './lion_ckpt/text2shape/chair/cfg.yml' + +config.merge_from_file(model_config) +lion = LION(config) +lion.load_model(model_path) + +if config.clipforge.enable: + input_t = ["a swivel chair, five wheels"] + device_str = 'cuda' + clip_model, clip_preprocess = clip.load( + config.clipforge.clip_model, device=device_str) + text = clip.tokenize(input_t).to(device_str) + clip_feat = [] + clip_feat.append(clip_model.encode_text(text).float()) + clip_feat = torch.cat(clip_feat, dim=0) + print('clip_feat', clip_feat.shape) +else: + clip_feat = None +output = lion.sample(1 if clip_feat is None else clip_feat.shape[0], clip_feat=clip_feat) +pts = output['points'] +img_name = "/tmp/tmp.png" +plot_points(pts, output_name=img_name) +img = Image.open(img_name) +img.show() diff --git a/env.yaml b/env.yaml new file mode 100644 index 0000000..fffe316 --- /dev/null +++ b/env.yaml @@ -0,0 +1,311 @@ +name: lion_env +channels: + - pytorch + - nvidia + - anaconda + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - argon2-cffi=20.1.0=py38h27cfd23_1 + - async_generator=1.10=pyhd3eb1b0_0 + - attrs=21.4.0=pyhd3eb1b0_0 + - backcall=0.2.0=pyhd3eb1b0_0 + - blas=1.0=mkl + - bleach=4.1.0=pyhd3eb1b0_0 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2020.10.14=0 + - certifi=2020.6.20=py38_0 + - cffi=1.15.0=py38hd667e15_1 + - cmake=3.18.2=ha30ef3c_0 + - cudatoolkit=11.1.74=h6bb024c_0 + - debugpy=1.5.1=py38h295c915_0 + - decorator=5.1.1=pyhd3eb1b0_0 + - defusedxml=0.7.1=pyhd3eb1b0_0 + - entrypoints=0.3=py38_0 + - expat=2.2.10=he6710b0_2 + - ffmpeg=4.3=hf484d3e_0 + - freetype=2.11.0=h70c0345_0 + - giflib=5.2.1=h7b6447c_0 + - gmp=6.2.1=h2531618_2 + - gnutls=3.6.15=he1e5248_0 + - importlib_metadata=4.8.2=hd3eb1b0_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - ipykernel=6.4.1=py38h06a4308_1 + - ipython=7.31.1=py38h06a4308_0 + - ipython_genutils=0.2.0=pyhd3eb1b0_1 + - ipywidgets=7.6.5=pyhd3eb1b0_1 + - jedi=0.18.1=py38h06a4308_0 + - jpeg=9d=h7f8727e_0 + - jupyter_client=7.1.2=pyhd3eb1b0_0 + - jupyter_core=4.9.1=py38h06a4308_0 + - jupyterlab_pygments=0.1.2=py_0 + - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 + - krb5=1.18.2=h173b8e3_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libcurl=7.71.1=h20c2e04_1 + - libedit=3.1.20191231=h14c3975_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgomp=9.3.0=h5101ec6_17 + - libiconv=1.15=h63c8f33_5 + - libidn2=2.3.2=h7f8727e_0 + - libpng=1.6.37=hbc83047_0 + - libsodium=1.0.18=h7b6447c_0 + - libssh2=1.9.0=h1ba5d50_1 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.2.0=h85742a9_0 + - libunistring=0.9.10=h27cfd23_0 + - libuv=1.40.0=h7b6447c_0 + - libwebp=1.2.0=h89dd481_0 + - libwebp-base=1.2.0=h27cfd23_0 + - lz4-c=1.9.3=h295c915_1 + - markupsafe=2.0.1=py38h27cfd23_0 + - matplotlib-inline=0.1.2=pyhd3eb1b0_2 + - mistune=0.8.4=py38h7b6447c_1000 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.1=py38hd3c417c_0 + - mkl_random=1.2.2=py38h51133e4_0 + - nbclient=0.5.3=pyhd3eb1b0_0 + - nbconvert=6.3.0=py38h06a4308_0 + - ncurses=6.3=h7f8727e_2 + - nest-asyncio=1.5.1=pyhd3eb1b0_0 + - 
nettle=3.7.3=hbbd107a_1 + - notebook=6.4.6=py38h06a4308_0 + - numpy=1.21.2=py38h20f2e39_0 + - numpy-base=1.21.2=py38h79a1101_0 + - olefile=0.46=pyhd3eb1b0_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1m=h7f8727e_0 + - packaging=21.3=pyhd3eb1b0_0 + - pandocfilters=1.5.0=pyhd3eb1b0_0 + - parso=0.8.3=pyhd3eb1b0_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pillow=8.4.0=py38h5aabda8_0 + - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0 + - prometheus_client=0.13.1=pyhd3eb1b0_0 + - prompt-toolkit=3.0.20=pyhd3eb1b0_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - pycparser=2.21=pyhd3eb1b0_0 + - pygments=2.11.2=pyhd3eb1b0_0 + - python=3.8.12=h12debd9_0 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - python-fastjsonschema=2.16.1=pyhd8ed1ab_0 + - python_abi=3.8=2_cp38 + - pytorch=1.10.2=py3.8_cuda11.1_cudnn8.0.5_0 + - pytorch-mutex=1.0=cuda + - pyzmq=22.3.0=py38h295c915_2 + - readline=8.1.2=h7f8727e_1 + - rhash=1.4.0=h1ba5d50_0 + - send2trash=1.8.0=pyhd3eb1b0_1 + - six=1.16.0=pyhd3eb1b0_0 + - sqlite=3.37.2=hc218d9a_0 + - terminado=0.9.4=py38h06a4308_0 + - testpath=0.5.0=pyhd3eb1b0_0 + - tk=8.6.11=h1ccaba5_0 + - torchaudio=0.10.2=py38_cu111 + - torchvision=0.11.3=py38_cu111 + - tornado=6.1=py38h27cfd23_0 + - traitlets=5.1.1=pyhd3eb1b0_0 + - wcwidth=0.2.5=pyhd3eb1b0_0 + - webencodings=0.5.1=py38_1 + - wheel=0.37.1=pyhd3eb1b0_0 + - widgetsnbextension=3.5.1=py38_0 + - xz=5.2.5=h7b6447c_0 + - zeromq=4.3.4=h2531618_0 + - zipp=3.7.0=pyhd3eb1b0_0 + - zlib=1.2.11=h7f8727e_4 + - zstd=1.4.5=h9ceee32_0 + - pip: + - about-time==3.1.1 + - absl-py==1.0.0 + - addict==2.4.0 + - aiohttp==3.8.1 + - aiosignal==1.2.0 + - alive-progress==2.2.0 + - antlr4-python3-runtime==4.9.3 + - anyio==3.5.0 + - astunparse==1.6.3 + - async-timeout==4.0.2 + - babel==2.9.1 + - cachetools==5.0.0 + - calmsize==0.1.3 + - ccimport==0.3.7 + - cftime==1.6.0 + - charset-normalizer==2.0.11 + - click==8.0.3 + - colorama==0.4.4 + - comet-ml==3.31.21 + - commonmark==0.9.1 + - configobj==5.0.6 + - crc32c==2.2.post0 + - cumm-cu111==0.2.8 + - cupy-cuda111==10.2.0 + - cycler==0.11.0 + - cython==0.29.20 + - dataclasses==0.6 + - deepspeed==0.6.5 + - deprecation==2.1.0 + - diffusers==0.11.1 + - docker-pycreds==0.4.0 + - drjit==0.2.1 + - dulwich==0.20.32 + - easydict==1.9 + - einops==0.4.0 + - everett==3.0.0 + - fastrlock==0.8 + - filelock==3.9.0 + - fire==0.4.0 + - flatbuffers==2.0 + - flatten-dict==0.4.2 + - fonttools==4.29.1 + - freetype-py==2.3.0 + - frozenlist==1.3.0 + - fsspec==2022.2.0 + - ftfy==6.1.1 + - future==0.18.2 + - fvcore==0.1.5.post20220512 + - gast==0.5.3 + - gitdb==4.0.9 + - gitpython==3.1.26 + - google-auth==2.6.0 + - google-auth-oauthlib==0.4.6 + - google-pasta==0.2.0 + - grapheme==0.6.0 + - grpcio==1.43.0 + - h5py==3.6.0 + - hjson==3.0.2 + - huggingface-hub==0.11.1 + - idna==3.3 + - imageio==2.15.0 + - imageio-ffmpeg==0.4.5 + - importlib-metadata==4.10.1 + - importlib-resources==5.4.0 + - iopath==0.1.10 + - jinja2==3.1.1 + - joblib==1.1.0 + - json5==0.9.6 + - jsonschema==4.4.0 + - jupyter-packaging==0.12.0 + - jupyter-server==1.15.6 + - jupyterlab==3.3.2 + - jupyterlab-server==2.11.2 + - keras==2.8.0 + - keras-preprocessing==1.1.2 + - kiwisolver==1.3.2 + - kornia==0.6.6 + - lark==1.1.2 + - libclang==14.0.1 + - llvmlite==0.39.0 + - loguru==0.6.0 + - markdown==3.3.6 + - matplotlib==3.5.1 + - matplotlib2tikz==0.7.6 + - meshio==5.3.4 + - mitsuba==3.0.1 + - mrcfile==1.3.0 + - multidict==6.0.2 + - multipledispatch==0.6.0 + - mypy-extensions==0.4.3 + - nbclassic==0.3.7 + - nbformat==5.2.0 + - nestargs==0.5.0 + - netcdf4==1.5.8 + 
- networkx==2.6.3 + - ninja==1.10.2.3 + - notebook-shim==0.1.0 + - numba==0.56.0 + - nvidia-ml-py3==7.352.0 + - oauthlib==3.2.0 + - omegaconf==2.2.2 + - open3d==0.15.2 + - opencv-python==4.5.5.64 + - openexr==1.3.7 + - opt-einsum==3.3.0 + - pandas==1.4.0 + - pathtools==0.1.2 + - pccm==0.3.4 + - pip==22.3.1 + - plyfile==0.7.4 + - portalocker==2.5.1 + - progressbar2==4.0.0 + - promise==2.3 + - protobuf==3.19.4 + - psutil==5.9.0 + - py-cpuinfo==8.0.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pybind11==2.10.0 + - pydeprecate==0.3.1 + - pyglet==1.5.23 + - pykeops==1.5 + - pymcubes==0.1.2 + - pyopengl==3.1.0 + - pyparsing==3.0.7 + - pyquaternion==0.9.9 + - pyrr==0.10.3 + - pyrsistent==0.18.1 + - python-swiftclient==4.0.0 + - python-utils==3.3.3 + - pytorch-lightning==1.5.1 + - pytorch3d==0.3.0 + - pytz==2021.3 + - pywavelets==1.2.0 + - pyyaml==6.0 + - regex==2022.3.15 + - requests==2.27.1 + - requests-oauthlib==1.3.1 + - requests-toolbelt==0.9.1 + - rich==12.3.0 + - rsa==4.8 + - ruamel-yaml==0.17.20 + - ruamel-yaml-clib==0.2.6 + - scikit-image==0.19.1 + - scikit-learn==1.0.2 + - scipy==1.8.0 + - seaborn==0.11.2 + - semantic-version==2.9.0 + - sentry-sdk==1.5.4 + - sharedarray==3.2.1 + - shortuuid==1.0.8 + - simple-parsing==0.0.18 + - simplejson==3.18.0 + - sklearn==0.0 + - smmap==5.0.0 + - sniffio==1.2.0 + - tabulate==0.8.9 + - tensorboard==2.8.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - tensorboardx==2.4.1 + - tensorflow-gpu==2.8.0 + - tensorflow-io-gcs-filesystem==0.25.0 + - termcolor==1.1.0 + - tf-estimator-nightly==2.8.0.dev2021122109 + - tflearn==0.5.0 + - tfrecord==1.14.1 + - threadpoolctl==3.1.0 + - tifffile==2022.2.2 + - tikzplotlib==0.10.1 + - tomlkit==0.10.0 + - torchmetrics==0.7.2 + - tqdm==4.62.3 + - trimesh==3.10.1 + - typing-extensions==4.2.0 + - typing-inspect==0.7.1 + - urllib3==1.26.8 + - wandb==0.12.10 + - webcolors==1.11.1 + - websocket-client==1.2.3 + - werkzeug==2.0.3 + - wrapt==1.13.3 + - wurlitzer==3.0.2 + - yacs==0.1.8 + - yarl==1.7.2 + - yaspin==2.1.0 diff --git a/models/adagn.py b/models/adagn.py new file mode 100644 index 0000000..6dcf10d --- /dev/null +++ b/models/adagn.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
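# [Editor's illustrative sketch, not part of the patch.] A minimal sanity check
# that an environment built from env.yaml above matches its key pins
# (python 3.8, pytorch 1.10.2 built against CUDA 11.1):
import sys
import torch

assert sys.version_info[:2] == (3, 8), sys.version_info
assert torch.__version__.startswith('1.10'), torch.__version__
assert torch.version.cuda == '11.1', torch.version.cuda
print('CUDA available:', torch.cuda.is_available())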
+""" +adaptive group norm +""" +from loguru import logger +import torch.nn as nn +import torch +import numpy as np +from utils.checker import * +from .dense import dense +import os + +class AdaGN(nn.Module): + ''' + adaptive group normalization + ''' + def __init__(self, ndim, cfg, n_channel): + """ + ndim: dim of the input features + n_channel: number of channels of the inputs + ndim_style: channel of the style features + """ + super().__init__() + style_dim = cfg.latent_pts.style_dim + init_scale = cfg.latent_pts.ada_mlp_init_scale + self.ndim = ndim + self.n_channel = n_channel + self.style_dim = style_dim + self.out_dim = n_channel * 2 + self.norm = nn.GroupNorm(8, n_channel) + in_channel = n_channel + self.emd = dense(style_dim, n_channel*2, init_scale=init_scale) + self.emd.bias.data[:in_channel] = 1 + self.emd.bias.data[in_channel:] = 0 + + def __repr__(self): + return f"AdaGN(GN(8, {self.n_channel}), Linear({self.style_dim}, {self.out_dim}))" + + def forward(self, image, style): + # style: B,D + # image: B,D,N,1 + CHECK2D(style) + style = self.emd(style) + if self.ndim == 3: #B,D,V,V,V + CHECK5D(image) + style = style.view(style.shape[0], -1, 1, 1, 1) # 5D + elif self.ndim == 2: # B,D,N,1 + CHECK4D(image) + style = style.view(style.shape[0], -1, 1, 1) # 4D + elif self.ndim == 1: # B,D,N + CHECK3D(image) + style = style.view(style.shape[0], -1, 1) # 4D + else: + raise NotImplementedError + + factor, bias = style.chunk(2, 1) + result = self.norm(image) + result = result * factor + bias + return result + + diff --git a/models/dense.py b/models/dense.py new file mode 100644 index 0000000..cbbde1e --- /dev/null +++ b/models/dense.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +""" copied and modified from https://github.com/CW-Huang/sdeflow-light/blob/524650bc5ad69522b3e0905672deef0650374512/lib/models/unet.py """ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import _calculate_fan_in_and_fan_out +import numpy as np + + +def _calculate_correct_fan(tensor, mode): + """ + copied and modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L337 + """ + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out', 'fan_avg'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + return fan_in if mode == 'fan_in' else fan_out + + +def kaiming_uniform_(tensor, gain=1., mode='fan_in'): + r"""Fills the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + Also known as He initialization. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + gain: multiplier to the dispersion + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode='fan_in') + """ + fan = _calculate_correct_fan(tensor, mode) + # gain = calculate_gain(nonlinearity, a) + var = gain / max(1., fan) + bound = math.sqrt(3.0 * var) # Calculate uniform bounds from standard deviation + with torch.no_grad(): + return tensor.uniform_(-bound, bound) + + +def variance_scaling_init_(tensor, scale): + return kaiming_uniform_(tensor, gain=1e-10 if scale == 0 else scale, mode='fan_avg') + + +def dense(in_channels, out_channels, init_scale=1.): + lin = nn.Linear(in_channels, out_channels) + variance_scaling_init_(lin.weight, scale=init_scale) + nn.init.zeros_(lin.bias) + return lin + +def conv2d(in_planes, out_planes, kernel_size=(3, 3), stride=1, dilation=1, padding=1, bias=True, padding_mode='zeros', + init_scale=1.): + conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, + bias=bias, padding_mode=padding_mode) + variance_scaling_init_(conv.weight, scale=init_scale) + if bias: + nn.init.zeros_(conv.bias) + return conv + + + diff --git a/models/distributions.py b/models/distributions.py new file mode 100644 index 0000000..e375bb8 --- /dev/null +++ b/models/distributions.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +import numpy as np + +@torch.jit.script +def sample_normal_jit(mu, sigma): + rho = mu.mul(0).normal_() + z = rho.mul_(sigma).add_(mu) + return z, rho + +class Normal: + def __init__(self, mu, log_sigma, sigma=None): + self.mu = mu + self.log_sigma = log_sigma + self.sigma = torch.exp(log_sigma) if sigma is None else sigma + + def sample(self, t=1.): + return sample_normal_jit(self.mu, self.sigma * t) + + def sample_given_rho(self, rho): + return rho * self.sigma + self.mu + + def mean(self): + return self.mu + + def log_p(self, samples): + normalized_samples = (samples - self.mu) / self.sigma + log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - self.log_sigma + return log_p + + diff --git a/models/latent_points_ada.py b/models/latent_points_ada.py new file mode 100644 index 0000000..a37854f --- /dev/null +++ b/models/latent_points_ada.py @@ -0,0 +1,273 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
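# [Editor's illustrative sketch, not part of the patch.] dense() above is a
# Linear layer with variance-scaling init (init_scale=0 collapses the weights
# to near-zero, as used for heads that should start inert), and Normal is the
# reparameterized Gaussian used for the VAE latents:
import torch
from models.dense import dense
from models.distributions import Normal

lin = dense(128, 256, init_scale=0.)       # near-zero weight init, zero bias
mu, log_sigma = torch.zeros(2, 8), torch.zeros(2, 8)
dist = Normal(mu, log_sigma)               # sigma = exp(0) = 1
z, _ = dist.sample(t=1.)                   # t scales sigma (sampling temperature)
logp = dist.log_p(z)                       # elementwise log N(z; mu, sigma)
assert z.shape == logp.shape == (2, 8)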
+ +import torch +from loguru import logger +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .pvcnn2_ada import \ + create_pointnet2_sa_components, create_pointnet2_fp_modules, LinearAttention, create_mlp_components, SharedMLP + +# the building block of encode and decoder for VAE + +class PVCNN2Unet(nn.Module): + """ + copied and modified from https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py#L172 + """ + def __init__(self, + num_classes, embed_dim, use_att, dropout=0.1, + extra_feature_channels=3, + input_dim=3, + width_multiplier=1, + voxel_resolution_multiplier=1, + time_emb_scales=1.0, + verbose=True, + condition_input=False, + point_as_feat=1, cfg={}, + sa_blocks={}, fp_blocks={}, + clip_forge_enable=0, + clip_forge_dim=512 + ): + super().__init__() + logger.info('[Build Unet] extra_feature_channels={}, input_dim={}', + extra_feature_channels, input_dim) + self.input_dim = input_dim + + self.clip_forge_enable = clip_forge_enable + self.sa_blocks = sa_blocks + self.fp_blocks = fp_blocks + self.point_as_feat = point_as_feat + self.condition_input = condition_input + assert extra_feature_channels >= 0 + self.time_emb_scales = time_emb_scales + self.embed_dim = embed_dim + ## assert(self.embed_dim == 0) + if self.embed_dim > 0: # has time embedding + # for prior model, we have time embedding, for VAE model, no time embedding + self.embedf = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + nn.Linear(embed_dim, embed_dim), + ) + + if self.clip_forge_enable: + self.clip_forge_mapping = nn.Linear(clip_forge_dim, embed_dim) + style_dim = cfg.latent_pts.style_dim + self.style_clip = nn.Linear(style_dim + embed_dim, style_dim) + + self.in_channels = extra_feature_channels + 3 + + sa_layers, sa_in_channels, channels_sa_features, _ = \ + create_pointnet2_sa_components( + input_dim=input_dim, + sa_blocks=self.sa_blocks, + extra_feature_channels=extra_feature_channels, + with_se=True, + embed_dim=embed_dim, # time embedding dim + use_att=use_att, dropout=dropout, + width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier, + verbose=verbose, cfg=cfg + ) + self.sa_layers = nn.ModuleList(sa_layers) + + self.global_att = None if not use_att else LinearAttention(channels_sa_features, 8, verbose=verbose) + + # only use extra features in the last fp module + sa_in_channels[0] = extra_feature_channels + input_dim - 3 + fp_layers, channels_fp_features = create_pointnet2_fp_modules( + fp_blocks=self.fp_blocks, in_channels=channels_sa_features, + sa_in_channels=sa_in_channels, + with_se=True, embed_dim=embed_dim, + use_att=use_att, dropout=dropout, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier, + verbose=verbose, cfg=cfg + ) + self.fp_layers = nn.ModuleList(fp_layers) + + layers, _ = create_mlp_components( + in_channels=channels_fp_features, + out_channels=[128, dropout, num_classes], # was 0.5 + classifier=True, dim=2, width_multiplier=width_multiplier, + cfg=cfg) + self.classifier = nn.ModuleList(layers) + + def get_timestep_embedding(self, timesteps, device): + if len(timesteps.shape) == 2 and timesteps.shape[1] == 1: + timesteps = timesteps[:,0] + assert(len(timesteps.shape) == 1), f'get shape: {timesteps.shape}' + timesteps = timesteps * self.time_emb_scales + + half_dim = self.embed_dim // 2 + emb = np.log(10000) / (half_dim - 1) + emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * 
-emb)).float().to(device) + emb = timesteps[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if self.embed_dim % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), "constant", 0) + assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim]) + return emb + + def forward(self, inputs, **kwargs): + # Input: coords: B3N + B = inputs.shape[0] + coords = inputs[:, :self.input_dim, :].contiguous() + features = inputs + temb = kwargs.get('t', None) + if temb is not None: + t = temb + if t.ndim == 0 and not len(t.shape) == 1: + t = t.view(1).expand(B) + temb = self.embedf(self.get_timestep_embedding(t, inputs.device + ))[:,:,None].expand(-1,-1,inputs.shape[-1]) + temb_ori = temb # B,embed_dim,Npoint + + style = kwargs['style'] + if self.clip_forge_enable: + clip_feat = kwargs['clip_feat'] + assert(clip_feat is not None), f'require clip_feat as input' + clip_feat = self.clip_forge_mapping(clip_feat) + style = torch.cat([style, clip_feat], dim=1).contiguous() + style = self.style_clip(style) + + coords_list, in_features_list = [], [] + for i, sa_blocks in enumerate(self.sa_layers): + in_features_list.append(features) + coords_list.append(coords) + if i > 0 and temb is not None: + #TODO: implement a sa_blocks forward function; check if is PVConv layer and kwargs get grid_emb, take as additional input + features = torch.cat([features,temb],dim=1) + features, coords, temb, _ = \ + sa_blocks ((features, + coords, temb, style)) + else: # i == 0 or temb is None + features, coords, temb, _ = \ + sa_blocks ((features, coords, temb, style)) + + in_features_list[0] = inputs[:, 3:, :].contiguous() + if self.global_att is not None: + features = self.global_att(features) + for fp_idx, fp_blocks in enumerate(self.fp_layers): + if temb is not None: + features, coords, temb, _ = fp_blocks(( + coords_list[-1-fp_idx], coords, + torch.cat([features,temb],dim=1), + in_features_list[-1-fp_idx], temb, style)) + else: + features, coords, temb, _ = fp_blocks(( + coords_list[-1-fp_idx], coords, + features, + in_features_list[-1-fp_idx], temb, style)) + + for l in self.classifier: + if isinstance(l, SharedMLP): + features = l(features, style) + else: + features = l(features) + return features + +class PointTransPVC(nn.Module): + # encoder : B,N,3 -> B,N,2*D + sa_blocks = [ # conv_configs, sa_configs + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (128, 128, 128))), + ] + fp_blocks = [ + ((128, 128), (128, 3, 8)), # fp_configs, conv_configs + ((128, 128), (128, 3, 8)), + ((128, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, zdim, input_dim, args={}): + super().__init__() + self.zdim = zdim + self.layers = PVCNN2Unet(2*zdim+input_dim*2, + embed_dim=0, use_att=1, extra_feature_channels=0, + input_dim=args.ddpm.input_dim, cfg=args, + sa_blocks=self.sa_blocks, fp_blocks=self.fp_blocks, + dropout=args.ddpm.dropout) + self.skip_weight = args.latent_pts.skip_weight + self.pts_sigma_offset = args.latent_pts.pts_sigma_offset + self.input_dim = input_dim + + def forward(self, inputs): + x, style = inputs + B,N,D = x.shape + output = self.layers(x.permute(0,2,1).contiguous(), style=style).permute(0,2,1).contiguous() # BND + + pt_mu_1d = output[:,:,:self.input_dim].contiguous() + pt_sigma_1d = output[:,:,self.input_dim:2*self.input_dim].contiguous() - self.pts_sigma_offset + + pt_mu_1d = self.skip_weight * pt_mu_1d + x + if self.zdim > 0: + ft_mu_1d 
= output[:,:,2*self.input_dim:-self.zdim].contiguous() + ft_sigma_1d = output[:,:,-self.zdim:].contiguous() + + mu_1d = torch.cat([pt_mu_1d, ft_mu_1d], dim=2).view(B,-1).contiguous() + sigma_1d = torch.cat([pt_sigma_1d, ft_sigma_1d], dim=2).view(B,-1).contiguous() + else: + mu_1d = pt_mu_1d.view(B,-1).contiguous() + sigma_1d = pt_sigma_1d.view(B,-1).contiguous() + return {'mu_1d': mu_1d, 'sigma_1d': sigma_1d} + +class LatentPointDecPVC(nn.Module): + """ input x: [B,Npoint,D] with [B,Npoint,3] + """ + sa_blocks = [ # conv_configs, sa_configs + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (128, 128, 128))), + ] + fp_blocks = [ + ((128, 128), (128, 3, 8)), # fp_configs, conv_configs + ((128, 128), (128, 3, 8)), + ((128, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, point_dim, context_dim, num_points=None, args={}, **kwargs): + super().__init__() + self.point_dim = point_dim + logger.info('[Build Dec] point_dim={}, context_dim={}', point_dim, context_dim) + self.context_dim = context_dim + self.point_dim + # self.num_points = num_points + if num_points is None: + self.num_points = args.data.tr_max_sample_points + else: + self.num_points = num_points + self.layers = PVCNN2Unet(point_dim, embed_dim=0, use_att=1, + extra_feature_channels=context_dim, + input_dim=args.ddpm.input_dim, cfg=args, + sa_blocks=self.sa_blocks, fp_blocks=self.fp_blocks, + dropout=args.ddpm.dropout) + self.skip_weight = args.latent_pts.skip_weight + + def forward(self, x, beta, context, style): + """ + Args: + x: Point clouds at some timestep t, (B, N, d). [not used] + beta: Time. (B, ). [not used] + context: Latent points, (B,N_pts*D_latent_pts), D_latent_pts = D_input + D_extra + style: Shape latents. (B,d). + Returns: + points: (B,N,3) + """ + + # CHECKDIM(context, 1, self.num_points*self.context_dim) + assert(context.shape[1] == self.num_points*self.context_dim) + context = context.view(-1,self.num_points,self.context_dim) # BND + x = context[:,:,:self.point_dim] + output = self.layers(context.permute(0,2,1).contiguous(), style=style).permute(0,2,1).contiguous() # BN3 + output = output * self.skip_weight + x + return output + diff --git a/models/latent_points_ada_localprior.py b/models/latent_points_ada_localprior.py new file mode 100644 index 0000000..4f71740 --- /dev/null +++ b/models/latent_points_ada_localprior.py @@ -0,0 +1,84 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
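# [Editor's illustrative sketch, not part of the patch.] A standalone version
# of PVCNN2Unet.get_timestep_embedding above: the classic sinusoidal embedding,
# with timesteps optionally pre-scaled by time_emb_scales.
import numpy as np
import torch

def timestep_embedding(timesteps, embed_dim, scale=1.0):
    timesteps = timesteps.float() * scale
    half = embed_dim // 2
    freqs = torch.exp(-np.log(10000) / (half - 1) * torch.arange(half).float())
    emb = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([emb.sin(), emb.cos()], dim=1)
    if embed_dim % 2 == 1:                 # zero-pad odd embedding dims
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb                             # (B, embed_dim)

emb = timestep_embedding(torch.arange(4), 64)
assert emb.shape == (4, 64)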
+import torch +from loguru import logger +import torch.nn as nn +import torch.nn.functional as F +from .latent_points_ada import PVCNN2Unet +from .utils import mask_inactive_variables + +# diffusion model for latent points +class PVCNN2Prior(PVCNN2Unet): + sa_blocks = [ # conv_configs, sa_configs + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 128))), + (None, (16, 0.8, 32, (128, 128, 128))), + ] + fp_blocks = [ + ((128, 128), (128, 3, 8)), # fp_configs, conv_configs + ((128, 128), (128, 3, 8)), + ((128, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, args, num_input_channels, cfg): + + # only cfg is used + self.clip_forge_enable = cfg.clipforge.enable + clip_forge_dim = cfg.clipforge.feat_dim + num_input_channels = num_classes = cfg.shapelatent.latent_dim + cfg.ddpm.input_dim + self.num_classes = num_classes + embed_dim = cfg.ddpm.time_dim + use_att = True + extra_feature_channels = cfg.shapelatent.latent_dim + self.num_points = cfg.data.tr_max_sample_points + dropout = cfg.ddpm.dropout + time_emb_scales = cfg.sde.embedding_scale # 1k default + logger.info('[Build Prior Model] nclass={}, embed_dim={}, use_att={},' + 'extra_feature_channels={}, dropout={}, time_emb_scales={} num_point={}', + num_classes, embed_dim, use_att, extra_feature_channels, dropout, time_emb_scales, + self.num_points) + # Attention: we are not using time_emb_scales here, but the embedding_scale + super().__init__( + num_classes, embed_dim, use_att, dropout=dropout, + input_dim=cfg.ddpm.input_dim, + extra_feature_channels=extra_feature_channels, + time_emb_scales=time_emb_scales, + verbose=True, + condition_input=False, + cfg=cfg, + sa_blocks=self.sa_blocks, + fp_blocks=self.fp_blocks, + clip_forge_enable=self.clip_forge_enable, clip_forge_dim=clip_forge_dim) + # init mixing logit + self.mixed_prediction = cfg.sde.mixed_prediction # This enables mixed prediction + if self.mixed_prediction: + logger.info('init-mixing_logit = {}, after sigmoid = {}', + cfg.sde.mixing_logit_init, torch.sigmoid(torch.tensor(cfg.sde.mixing_logit_init)) + ) + init = cfg.sde.mixing_logit_init * torch.ones(size=[1, num_input_channels*self.num_points, 1, 1]) + self.mixing_logit = torch.nn.Parameter(init, requires_grad=True) + self.is_active = None + else: # no mixing_logit + self.mixing_logit = None + self.is_active = None + + def forward(self, x, t, *args, **kwargs): #x0=None): + # Input: x: B,ND or B,ND,1,1 + # require shape for x: B,C,N + ## CHECKEQ(x.shape[-1], self.num_classes) + assert('condition_input' in kwargs), 'require condition_input' + if self.mixed_prediction and self.is_active is not None: + x = mask_inactive_variables(x, self.is_active) + input_shape = x.shape + x = x.view(-1,self.num_points,self.num_classes).permute(0,2,1).contiguous() + B = x.shape[0] + out = super().forward(x, t=t, style=kwargs['condition_input'].squeeze(-1).squeeze(-1), clip_feat=kwargs.get('clip_feat', None)) + return out.permute(0,2,1).contiguous().view(input_shape) + # -1,self.num_classes) # BDN -> BND -> BN,D diff --git a/models/lion.py b/models/lion.py new file mode 100644 index 0000000..339aaf5 --- /dev/null +++ b/models/lion.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
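# [Editor's illustrative sketch, not part of the patch.] The local prior above
# stores the diffusion state as flattened latent points, then treats them as a
# point cloud again inside forward(): reshape (B, N*C, 1, 1) to (B, N, C),
# permute to channels-first for the PVCNN backbone, and undo it on the way
# out. A minimal standalone demo of that round trip, with hypothetical sizes:
import torch

B, N, C = 2, 2048, 4                       # C = ddpm.input_dim + shapelatent.latent_dim
x = torch.randn(B, N * C, 1, 1)            # diffusion-state layout
x_pts = x.view(B, N, C).permute(0, 2, 1)   # (B, C, N) fed to the U-Net
out = x_pts.permute(0, 2, 1).reshape(x.shape)  # back to the diffusion layout
assert out.shape == x.shape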
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +from models.vae_adain import Model as VAE +from models.latent_points_ada_localprior import PVCNN2Prior as LocalPrior +from utils.diffusion_pvd import DiffusionDiscretized +from utils.vis_helper import plot_points +from utils.model_helper import import_model +from diffusers import DDPMScheduler +import torch +from matplotlib import pyplot as plt + +class LION(object): + def __init__(self, cfg): + self.vae = VAE(cfg).cuda() + GlobalPrior = import_model(cfg.latent_pts.style_prior) + global_prior = GlobalPrior(cfg.sde, cfg.latent_pts.style_dim, cfg).cuda() + local_prior = LocalPrior(cfg.sde, cfg.shapelatent.latent_dim, cfg).cuda() + self.priors = torch.nn.ModuleList([global_prior, local_prior]) + self.scheduler = DDPMScheduler(clip_sample=False, + beta_start=cfg.ddpm.beta_1, beta_end=cfg.ddpm.beta_T, beta_schedule=cfg.ddpm.sched_mode, + num_train_timesteps=cfg.ddpm.num_steps, variance_type=cfg.ddpm.model_var_type) + self.diffusion = DiffusionDiscretized(None, None, cfg) + # self.load_model(cfg) + + def load_model(self, model_path): + # model_path = cfg.ckpt.path + ckpt = torch.load(model_path) + self.priors.load_state_dict(ckpt['dae_state_dict']) + self.vae.load_state_dict(ckpt['vae_state_dict']) + print(f'INFO finish loading from {model_path}') + + @torch.no_grad() + def sample(self, num_samples=10, clip_feat=None, save_img=False): + self.scheduler.set_timesteps(1000, device='cuda') + timesteps = self.scheduler.timesteps + latent_shape = self.vae.latent_shape() + global_prior, local_prior = self.priors[0], self.priors[1] + assert(not local_prior.mixed_prediction and not global_prior.mixed_prediction) + sampled_list = [] + output_dict = {} + + # start sample global prior + x_T_shape = [num_samples] + latent_shape[0] + x_noisy = torch.randn(size=x_T_shape, device='cuda') + condition_input = None + for i, t in enumerate(timesteps): + t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1) + noise_pred = global_prior(x=x_noisy, t=t_tensor.float(), + condition_input=condition_input, clip_feat=clip_feat) + x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample + sampled_list.append(x_noisy) + output_dict['z_global'] = x_noisy + + condition_input = x_noisy + condition_input = self.vae.global2style(condition_input) + + # start sample local prior + x_T_shape = [num_samples] + latent_shape[1] + x_noisy = torch.randn(size=x_T_shape, device='cuda') + + for i, t in enumerate(timesteps): + t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1) + noise_pred = local_prior(x=x_noisy, t=t_tensor.float(), + condition_input=condition_input, clip_feat=clip_feat) + x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample + sampled_list.append(x_noisy) + output_dict['z_local'] = x_noisy + + # decode the latent + output = self.vae.sample(num_samples=num_samples, decomposed_eps=sampled_list) + if save_img: + out_name = plot_points(output, "/tmp/tmp.png") + print(f'INFO save plot image at {out_name}') + output_dict['points'] = output + return output_dict + + def get_mixing_component(self, noise_pred, t): + # usage: + # if global_prior.mixed_prediction: + # mixing_component = self.get_mixing_component(noise_pred, t) + # coeff = torch.sigmoid(global_prior.mixing_logit) + # noise_pred = (1 - coeff) * mixing_component + coeff * noise_pred + + alpha_bar = 
self.scheduler.alphas_cumprod[t] + one_minus_alpha_bars_sqrt = np.sqrt(1.0 - alpha_bar) + return noise_pred * one_minus_alpha_bars_sqrt diff --git a/models/pvcnn2.py b/models/pvcnn2.py new file mode 100644 index 0000000..f7d5fa6 --- /dev/null +++ b/models/pvcnn2.py @@ -0,0 +1,557 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +""" +copied and modified from source: + https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py + and functions under + https://github.com/alexzhou907/PVD/tree/9747265a5f141e5546fd4f862bfa66aa59f1bd33/modules +""" +import copy +import functools +from loguru import logger +from einops import rearrange +import torch.nn as nn +import torch +import numpy as np +import third_party.pvcnn.functional as F +from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd + +class SE3d(nn.Module): + def __init__(self, channel, reduction=8): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + self.channel = channel + def __repr__(self): + return f"SE({self.channel}, {self.channel})" + def forward(self, inputs): + return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1) + +class LinearAttention(nn.Module): + """ + copied and modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L159 + """ + def __init__(self, dim, heads = 4, dim_head = 32, verbose=True): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + ''' + Args: + x: torch.tensor (B,C,N), C=num-channels, N=num-points + Returns: + out: torch.tensor (B,C,N) + ''' + x = x.unsqueeze(-1) # add w dimension + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + out = self.to_out(out) + out = out.squeeze(-1) # B,C,N,1 -> B,C,N + return out + + +def swish(input): + return input * torch.sigmoid(input) + + +class Swish(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return swish(input) + + +class BallQuery(nn.Module): + def __init__(self, radius, num_neighbors, include_coordinates=True): + super().__init__() + self.radius = radius + self.num_neighbors = num_neighbors + self.include_coordinates = include_coordinates + + @custom_bwd + def backward(self, *args, **kwargs): + return super().backward(*args, **kwargs) + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, points_coords, centers_coords, points_features=None): + # input: BCN, 
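# [Editor's illustrative sketch, not part of the patch.] Intended use of the
# LION wrapper above, mirroring demo.py; the config and checkpoint paths are
# hypothetical. (Note: get_mixing_component above calls np.sqrt, so
# models/lion.py also needs an "import numpy as np".)
from default_config import cfg
from models.lion import LION

cfg.merge_from_file('some_exp/cfg.yml')             # hypothetical experiment config
lion = LION(cfg)
lion.load_model('some_exp/checkpoints/model.pt')    # expects dae/vae state dicts
out = lion.sample(num_samples=8, save_img=False)
points = out['points']                              # (8, N, 3) generated point clouds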
BCN + # returns: + # neighbor_features: B,D(+3),Ncenter + points_coords = points_coords.contiguous() + centers_coords = centers_coords.contiguous() + neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors) + neighbor_coordinates = F.grouping(points_coords, neighbor_indices) + neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1) + + if points_features is None: + assert self.include_coordinates, 'No Features For Grouping' + neighbor_features = neighbor_coordinates + else: + neighbor_features = F.grouping(points_features, neighbor_indices) + if self.include_coordinates: + neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1) + return neighbor_features + + def extra_repr(self): + return 'radius={}, num_neighbors={}{}'.format( + self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '') + +class SharedMLP(nn.Module): + def __init__(self, in_channels, out_channels, dim=1): + super().__init__() + if dim==1: + conv = nn.Conv1d + else: + conv = nn.Conv2d + bn = nn.GroupNorm + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + layers = [] + for oc in out_channels: + layers.append( conv(in_channels, oc, 1)) + layers.append(bn(8, oc)) + layers.append(Swish()) + in_channels = oc + self.layers = nn.Sequential(*layers) + + def forward(self, inputs): + if isinstance(inputs, (list, tuple)): + return (self.layers(inputs[0]), *inputs[1:]) + else: + return self.layers(inputs) + +class Voxelization(nn.Module): + def __init__(self, resolution, normalize=True, eps=0): + super().__init__() + self.r = int(resolution) + self.normalize = normalize + self.eps = eps + + def forward(self, features, coords): + # features: B,D,N + # coords: B,3,N + coords = coords.detach() + norm_coords = coords - coords.mean(2, keepdim=True) + if self.normalize: + norm_coords = norm_coords / (norm_coords.norm( + dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + + self.eps) + 0.5 + else: + norm_coords = (norm_coords + 1) / 2.0 + norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1) + vox_coords = torch.round(norm_coords).to(torch.int32) + if features is None: + return features, norm_coords + return F.avg_voxelize(features, vox_coords, self.r), norm_coords + + def extra_repr(self): + return 'resolution={}{}'.format( + self.r, + ', normalized eps = {}'.format(self.eps) if self.normalize else '') + +class PVConv(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size, resolution, + normalize=1, eps=0, with_se=False, + add_point_feat=True, attention=False, + dropout=0.1, verbose=True + ): + super().__init__() + self.resolution = resolution + self.voxelization = Voxelization(resolution, + normalize=normalize, + eps=eps) + # For each PVConv we use (Conv3d, GroupNorm(8), Swish, dropout, Conv3d, GroupNorm(8), Attention) + voxel_layers = [ + nn.Conv3d(in_channels, + out_channels, + kernel_size, stride=1, + padding=kernel_size // 2), + nn.GroupNorm(8, out_channels), + Swish(), + nn.Dropout(dropout), + nn.Conv3d(out_channels, out_channels, + kernel_size, stride=1, + padding=kernel_size // 2), + nn.GroupNorm(8, out_channels) + ] + if with_se: + voxel_layers.append(SE3d(out_channels)) + self.voxel_layers = nn.Sequential(*voxel_layers) + if attention: + self.attn = LinearAttention(out_channels, verbose=verbose) + else: + self.attn = None + if add_point_feat: + self.point_features = SharedMLP(in_channels, out_channels) #, **mlp_kwargs) + self.add_point_feat = 
add_point_feat + + def forward(self, inputs): + ''' + Args: + inputs: tuple of features and coords + features: B,feat-dim,num-points + coords: B,3, num-points + Returns: + fused_features: in (B,out-feat-dim,num-points) + coords : in (B, 3, num_points); same as the input coords + ''' + features = inputs[0] + coords_input = inputs[1] + time_emb = inputs[2] + ## features, coords_input, time_emb = inputs + if coords_input.shape[1] > 3: + coords = coords_input[:,:3] # the last 3 dim are other point attributes if any + else: + coords = coords_input + assert (features.shape[0] == coords.shape[0] + ), f'get feat: {features.shape} and {coords.shape}' + assert (features.shape[2] == coords.shape[2] + ), f'get feat: {features.shape} and {coords.shape}' + assert (coords.shape[1] == 3 + ), f'expect coords: B,3,Npoint, get: {coords.shape}' + # features: B,D,N; point_features + # coords: B,3,N + voxel_features_4d, voxel_coords = self.voxelization(features, coords) + r = self.resolution + B = coords.shape[0] + voxel_features_4d = self.voxel_layers(voxel_features_4d) + voxel_features = F.trilinear_devoxelize(voxel_features_4d, voxel_coords, + r, self.training) + + fused_features = voxel_features + if self.add_point_feat: + fused_features = fused_features + self.point_features(features) + if self.attn is not None: + fused_features = self.attn(fused_features) + if time_emb is None: + time_emb = {'voxel_features_4d': voxel_features_4d, 'resolution': self.resolution, 'training': self.training} + return fused_features, coords_input, time_emb #inputs[2] + + +class PointNetAModule(nn.Module): + def __init__(self, in_channels, out_channels, include_coordinates=True): + super().__init__() + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] + + mlps = [] + total_out_channels = 0 + for _out_channels in out_channels: + mlps.append( + SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0), + out_channels=_out_channels, dim=1) + ) + total_out_channels += _out_channels[-1] + + self.include_coordinates = include_coordinates + self.out_channels = total_out_channels + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + features, coords, time_emb = inputs + if self.include_coordinates: + features = torch.cat([features, coords], dim=1) + coords = torch.zeros((coords.size(0), 3, 1), device=coords.device) + if len(self.mlps) > 1: + features_list = [] + for mlp in self.mlps: + features_list.append(mlp(features).max(dim=-1, keepdim=True).values) + return torch.cat(features_list, dim=1), coords, time_emb + else: + return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords, time_emb + + def extra_repr(self): + return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}' + + +class PointNetSAModule(nn.Module): + def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True): + super().__init__() + + if not isinstance(radius, (list, tuple)): + radius = [radius] + if not isinstance(num_neighbors, (list, tuple)): + num_neighbors = [num_neighbors] * len(radius) + assert len(radius) == len(num_neighbors) + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] * len(radius) + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] * len(radius) + assert len(radius) == len(out_channels) + + groupers, mlps = [], [] + total_out_channels = 0 + for _radius, 
_out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors): + groupers.append( + BallQuery(radius=_radius, num_neighbors=_num_neighbors, + include_coordinates=include_coordinates) + ) + # logger.info('create MLP: in_channel={}, out_channels={}', + # in_channels + (3 if include_coordinates else 0),_out_channels) + mlps.append( + SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0) , + out_channels=_out_channels, dim=2) + ) + total_out_channels += _out_channels[-1] + + self.num_centers = num_centers + self.out_channels = total_out_channels + self.groupers = nn.ModuleList(groupers) + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + # features, coords, _ = inputs + features = inputs[0] + coords = inputs[1] # B3N + if coords.shape[1] > 3: + coords = coords[:,:3] + + centers_coords = F.furthest_point_sample(coords, self.num_centers) + # centers_coords: B,D,N + S = centers_coords.shape[-1] + time_emb = inputs[2] + time_emb = time_emb[:,:,:S] if \ + time_emb is not None and type(time_emb) is not dict \ + else time_emb + + features_list = [] + c = 0 + for grouper, mlp in zip(self.groupers, self.mlps): + c += 1 + grouper_output = grouper(coords, centers_coords, features) + features_list.append( + mlp(grouper_output + ).max(dim=-1).values + ) + if len(features_list) > 1: + return torch.cat(features_list, dim=1), centers_coords, time_emb + else: + return features_list[0], centers_coords, time_emb + + def extra_repr(self): + return f'num_centers={self.num_centers}, out_channels={self.out_channels}' + + +class PointNetFPModule(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1) + + def forward(self, inputs): + if len(inputs) == 4: + points_coords, centers_coords, centers_features, time_emb = inputs + points_features = None + else: + points_coords, centers_coords, centers_features, points_features, time_emb = inputs + interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features) + if points_features is not None: + interpolated_features = torch.cat( + [interpolated_features, points_features], dim=1 + ) + if time_emb is not None: + B,D,S = time_emb.shape + N = points_coords.shape[-1] + time_emb = time_emb[:,:,0:1].expand(-1,-1,N) + return self.mlp(interpolated_features), points_coords, time_emb + +def _linear_gn_relu(in_channels, out_channels): + return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) + + +def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1): + r = width_multiplier + + if dim == 1: + block = _linear_gn_relu + else: + block = SharedMLP + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None): + return nn.Sequential(), in_channels, in_channels + + layers = [] + for oc in out_channels[:-1]: + if oc < 1: + layers.append(nn.Dropout(oc)) + else: + oc = int(r * oc) + layers.append(block(in_channels, oc)) + in_channels = oc + if dim == 1: + if classifier: + layers.append(nn.Linear(in_channels, out_channels[-1])) + else: + layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1]))) + else: + if classifier: + layers.append(nn.Conv1d(in_channels, out_channels[-1], 1)) + else: + layers.append(SharedMLP(in_channels, int(r * out_channels[-1]))) + return layers, out_channels[-1] if classifier else int(r 
* out_channels[-1]) + + +def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1, verbose=True): + r, vr = width_multiplier, voxel_resolution_multiplier + + layers, concat_channels = [], 0 + c = 0 + for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks): + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = k % 2 == 0 and k > 0 and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + with_se=with_se, normalize=normalize, eps=eps, verbose=verbose) + + if c == 0: + layers.append(block(in_channels, out_channels)) + else: + layers.append(block(in_channels+embed_dim, out_channels)) + in_channels = out_channels + concat_channels += out_channels + c += 1 + return layers, in_channels, concat_channels + + +def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, + input_dim=3, + embed_dim=64, use_att=False, force_att=0, + dropout=0.1, with_se=False, normalize=True, eps=0, has_temb=1, + width_multiplier=1, voxel_resolution_multiplier=1, verbose=True): + """ + Returns: + in_channels: the last output channels of the sa blocks + """ + r, vr = width_multiplier, voxel_resolution_multiplier + in_channels = extra_feature_channels + input_dim + + sa_layers, sa_in_channels = [], [] + c = 0 + num_centers = None + for conv_configs, sa_configs in sa_blocks: + k = 0 + sa_in_channels.append(in_channels) + sa_blocks = [] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = ( (c+1) % 2 == 0 and use_att and p == 0 ) or (force_att and c > 0) + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial( + PVConv, kernel_size=3, + resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se, + normalize=normalize, eps=eps, verbose=verbose) + + if c == 0: + sa_blocks.append(block(in_channels, out_channels)) + elif k ==0: + sa_blocks.append(block(in_channels+embed_dim*has_temb, out_channels)) + in_channels = out_channels + k += 1 + extra_feature_channels = in_channels + + if sa_configs is not None: + num_centers, radius, num_neighbors, out_channels = sa_configs + _out_channels = [] + for oc in out_channels: + if isinstance(oc, (list, tuple)): + _out_channels.append([int(r * _oc) for _oc in oc]) + else: + _out_channels.append(int(r * oc)) + out_channels = _out_channels + if num_centers is None: + block = PointNetAModule + else: + block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius, + num_neighbors=num_neighbors) + sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim*has_temb if k==0 else 0 ), + out_channels=out_channels, + include_coordinates=True)) + in_channels = extra_feature_channels = sa_blocks[-1].out_channels + c += 1 + + if len(sa_blocks) == 1: + sa_layers.append(sa_blocks[0]) + else: + sa_layers.append(nn.Sequential(*sa_blocks)) + + return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers + + +def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False, + dropout=0.1, has_temb=1, + with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1, + verbose=True): + r, vr = width_multiplier, 
voxel_resolution_multiplier + + fp_layers = [] + c = 0 + + for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks): + fp_blocks = [] + out_channels = tuple(int(r * oc) for oc in fp_configs) + fp_blocks.append( + PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim*has_temb, + out_channels=out_channels) + ) + in_channels = out_channels[-1] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, + resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se, # with_se_relu=True, + normalize=normalize, eps=eps, + verbose=verbose) + + fp_blocks.append(block(in_channels, out_channels)) + in_channels = out_channels + if len(fp_blocks) == 1: + fp_layers.append(fp_blocks[0]) + else: + fp_layers.append(nn.Sequential(*fp_blocks)) + + c += 1 + + return fp_layers, in_channels + + diff --git a/models/pvcnn2_ada.py b/models/pvcnn2_ada.py new file mode 100644 index 0000000..69b4f7f --- /dev/null +++ b/models/pvcnn2_ada.py @@ -0,0 +1,568 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +""" +copied and modified from source: + https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py + and functions under + https://github.com/alexzhou907/PVD/tree/9747265a5f141e5546fd4f862bfa66aa59f1bd33/modules +""" +import copy +import functools +from loguru import logger +from einops import rearrange +import torch.nn as nn +import torch +import numpy as np +import third_party.pvcnn.functional as F +# from utils.checker import * +from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd +from .adagn import AdaGN +import os +quiet = int(os.environ.get('quiet', 0)) +class SE3d(nn.Module): + def __init__(self, channel, reduction=8): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + self.channel = channel + def __repr__(self): + return f"SE({self.channel}, {self.channel})" + def forward(self, inputs): + return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1) + +class LinearAttention(nn.Module): + """ + copied and modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L159 + """ + def __init__(self, dim, heads = 4, dim_head = 32, verbose=True): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + ''' + Args: + x: torch.tensor (B,C,N), C=num-channels, N=num-points + Returns: + out: torch.tensor (B,C,N) + ''' 
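# [Editor's illustrative sketch, not part of the patch.] What LinearAttention
# computes, unrolled for a single head in plain torch: softmax over the keys,
# then two einsums, giving O(N * d^2) cost instead of the O(N^2 * d) of
# vanilla attention over N points:
import torch

q = k = v = torch.randn(2, 32, 1024)           # (B, dim_head, N), one head
k = k.softmax(dim=-1)                          # normalize over points
context = torch.einsum('bdn,ben->bde', k, v)   # (B, d, d) key-weighted values
out = torch.einsum('bde,bdn->ben', context, q) # (B, dim_head, N)
assert out.shape == v.shape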
+ x = x.unsqueeze(-1) # add w dimension + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + out = self.to_out(out) + out = out.squeeze(-1) # B,C,N,1 -> B,C,N + return out + + +def swish(input): + return input * torch.sigmoid(input) + + +class Swish(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return swish(input) + + +class BallQuery(nn.Module): + def __init__(self, radius, num_neighbors, include_coordinates=True): + super().__init__() + self.radius = radius + self.num_neighbors = num_neighbors + self.include_coordinates = include_coordinates + + @custom_bwd + def backward(self, *args, **kwargs): + return super().backward(*args, **kwargs) + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, points_coords, centers_coords, points_features=None): + # input: BCN, BCN + # neighbor_features: B,D(+3),Ncenter + points_coords = points_coords.contiguous() + centers_coords = centers_coords.contiguous() + neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors) + neighbor_coordinates = F.grouping(points_coords, neighbor_indices) + neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1) + + if points_features is None: + assert self.include_coordinates, 'No Features For Grouping' + neighbor_features = neighbor_coordinates + else: + neighbor_features = F.grouping(points_features, neighbor_indices) + if self.include_coordinates: + neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1) + return neighbor_features + + def extra_repr(self): + return 'radius={}, num_neighbors={}{}'.format( + self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '') + +class SharedMLP(nn.Module): + def __init__(self, in_channels, out_channels, dim=1, cfg={}): + + assert(len(cfg) > 0), cfg + super().__init__() + if dim==1: + conv = nn.Conv1d + else: + conv = nn.Conv2d + bn = functools.partial(AdaGN, dim, cfg) + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + layers = [] + for oc in out_channels: + layers.append(conv(in_channels, oc, 1)) + layers.append(bn(oc)) + layers.append(Swish()) + in_channels = oc + self.layers = nn.ModuleList(layers) + + def forward(self, *inputs): + if len(inputs) == 1 and len(inputs[0]) == 4: + # try to fix thwn SharedMLP is the first layer + inputs = inputs[0] + if len(inputs) == 1: + raise NotImplementedError + elif len(inputs) == 4: + assert(len(inputs) == 4), 'input, style' + x, _, _, style = inputs + for l in self.layers: + if isinstance(l, AdaGN): + x = l(x, style) + else: + x = l(x) + return (x, *inputs[1:]) + elif len(inputs) == 2: + x, style = inputs + for l in self.layers: + if isinstance(l, AdaGN): + x = l(x, style) + else: + x = l(x) + return x + else: + raise NotImplementedError + +class Voxelization(nn.Module): + def __init__(self, resolution, normalize=True, eps=0): + super().__init__() + self.r = int(resolution) + self.normalize = normalize + self.eps = eps + + def forward(self, features, coords): + # features: B,D,N + # coords: B,3,N + coords = coords.detach() + norm_coords = coords - coords.mean(2, keepdim=True) + if self.normalize: + norm_coords = norm_coords / (norm_coords.norm( 
+ dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + + self.eps) + 0.5 + else: + norm_coords = (norm_coords + 1) / 2.0 + norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1) + vox_coords = torch.round(norm_coords).to(torch.int32) + if features is None: + return features, norm_coords + return F.avg_voxelize(features, vox_coords, self.r), norm_coords + + def extra_repr(self): + return 'resolution={}{}'.format( + self.r, + ', normalized eps = {}'.format(self.eps) if self.normalize else '') + +class PVConv(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size, resolution, + normalize=1, eps=0, with_se=False, + add_point_feat=True, attention=False, + dropout=0.1, verbose=True, + cfg={} + ): + super().__init__() + assert(len(cfg) > 0), cfg + self.resolution = resolution + self.voxelization = Voxelization(resolution, + normalize=normalize, + eps=eps) + # For each PVConv we use (Conv3d, GroupNorm(8), Swish, dropout, Conv3d, GroupNorm(8), Attention) + NormLayer = functools.partial(AdaGN, 3, cfg) + voxel_layers = [ + nn.Conv3d(in_channels , + out_channels, + kernel_size, stride=1, + padding=kernel_size // 2), + NormLayer(out_channels), + Swish(), + nn.Dropout(dropout), + nn.Conv3d(out_channels, out_channels, + kernel_size, stride=1, + padding=kernel_size // 2), + NormLayer(out_channels) + ] + if with_se: + voxel_layers.append(SE3d(out_channels)) + self.voxel_layers = nn.ModuleList(voxel_layers) + if attention: + self.attn = LinearAttention(out_channels, verbose=verbose) + else: + self.attn = None + if add_point_feat: + self.point_features = SharedMLP(in_channels, out_channels, cfg=cfg) + self.add_point_feat = add_point_feat + + def forward(self, inputs): + ''' + Args: + inputs: tuple of features and coords + features: B,feat-dim,num-points + coords: B,3, num-points + time_emd: B,D; time embedding + style: B,D; global latent + Returns: + fused_features: in (B,out-feat-dim,num-points) + coords : in (B, 3 or 6, num_points); same as the input coords + ''' + features = inputs[0] + coords_input= inputs[1] + time_emb = inputs[2] + style = inputs[3] + if coords_input.shape[1] > 3: + coords = coords_input[:,:3] + else: + coords = coords_input + assert (features.shape[0] == coords.shape[0] + ), f'get feat: {features.shape} and {coords.shape}' + assert (features.shape[2] == coords.shape[2] + ), f'get feat: {features.shape} and {coords.shape}' + assert (coords.shape[1] == 3 + ), f'expect coords: B,3,Npoint, get: {coords.shape}' + # features: B,D,N; point_features + # coords: B,3,N + voxel_features_4d, voxel_coords = self.voxelization(features, coords) + r = self.resolution + B = coords.shape[0] + + for voxel_layers in self.voxel_layers: + if isinstance(voxel_layers, AdaGN): + voxel_features_4d = voxel_layers(voxel_features_4d, style) + else: + voxel_features_4d = voxel_layers(voxel_features_4d) + voxel_features = F.trilinear_devoxelize(voxel_features_4d, voxel_coords, + r, self.training) + + fused_features = voxel_features + if self.add_point_feat: + fused_features = fused_features + self.point_features(features, style) + if self.attn is not None: + fused_features = self.attn(fused_features) + return fused_features, coords_input, time_emb, style + + +class PointNetAModule(nn.Module): + def __init__(self, in_channels, out_channels, include_coordinates=True, cfg={}): + super().__init__() + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] + + mlps = [] + 
total_out_channels = 0 + for _out_channels in out_channels: + mlps.append( + SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0), + out_channels=_out_channels, dim=1, cfg=cfg) + ) + total_out_channels += _out_channels[-1] + + self.include_coordinates = include_coordinates + self.out_channels = total_out_channels + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + features, coords, time_emb, style = inputs + if self.include_coordinates: + features = torch.cat([features, coords], dim=1) + coords = torch.zeros((coords.size(0), 3, 1), device=coords.device) + if len(self.mlps) > 1: + features_list = [] + for mlp in self.mlps: + features_list.append(mlp(features, style).max(dim=-1, keepdim=True).values) + return torch.cat(features_list, dim=1), coords, time_emb + else: + return self.mlps[0](features, style).max(dim=-1, keepdim=True).values, coords, time_emb + + def extra_repr(self): + return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}' + + +class PointNetSAModule(nn.Module): + def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True, + cfg={}): + super().__init__() + if not isinstance(radius, (list, tuple)): + radius = [radius] + if not isinstance(num_neighbors, (list, tuple)): + num_neighbors = [num_neighbors] * len(radius) + assert len(radius) == len(num_neighbors) + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] * len(radius) + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] * len(radius) + assert len(radius) == len(out_channels) + + groupers, mlps = [], [] + total_out_channels = 0 + for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors): + groupers.append( + BallQuery(radius=_radius, num_neighbors=_num_neighbors, + include_coordinates=include_coordinates) + ) + mlps.append( + SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0), + out_channels=_out_channels, dim=2, cfg=cfg) + ) + total_out_channels += _out_channels[-1] + + self.num_centers = num_centers + self.out_channels = total_out_channels + self.groupers = nn.ModuleList(groupers) + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + features = inputs[0] + coords = inputs[1] # B3N + style = inputs[3] + if coords.shape[1] > 3: + coords = coords[:,:3] + + centers_coords = F.furthest_point_sample(coords, self.num_centers) + # centers_coords: B,D,N + S = centers_coords.shape[-1] + time_emb = inputs[2] + time_emb = time_emb[:,:,:S] if \ + time_emb is not None and type(time_emb) is not dict \ + else time_emb + + features_list = [] + c = 0 + for grouper, mlp in zip(self.groupers, self.mlps): + c += 1 + grouper_output = grouper(coords, centers_coords, features ) + features_list.append( + mlp(grouper_output, style + ).max(dim=-1).values + ) + + if len(features_list) > 1: + return torch.cat(features_list, dim=1), centers_coords, time_emb, style + else: + return features_list[0], centers_coords, time_emb, style + + def extra_repr(self): + return f'num_centers={self.num_centers}, out_channels={self.out_channels}' + + +class PointNetFPModule(nn.Module): + def __init__(self, in_channels, out_channels, cfg={}): + super().__init__() + self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1, cfg=cfg) + + def forward(self, inputs): + if len(inputs) == 5: + points_coords, centers_coords, centers_features, time_emb, style = inputs + points_features = None + elif len(inputs) == 6: + 
points_coords, centers_coords, centers_features, points_features, time_emb, style = inputs + else: + raise NotImplementedError + + interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features) + if points_features is not None: + interpolated_features = torch.cat( + [interpolated_features, points_features], dim=1 + ) + if time_emb is not None: + B,D,S = time_emb.shape + N = points_coords.shape[-1] + time_emb = time_emb[:,:,0:1].expand(-1,-1,N) + return self.mlp(interpolated_features, style), points_coords, time_emb, style + +def _linear_gn_relu(in_channels, out_channels): + return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) + +def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1, cfg={}): + r = width_multiplier + + if dim == 1: + block = _linear_gn_relu + else: + block = SharedMLP + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None): + return nn.Sequential(), in_channels, in_channels + + layers = [] + for oc in out_channels[:-1]: + if oc < 1: + layers.append(nn.Dropout(oc)) + else: + oc = int(r * oc) + layers.append(block(in_channels, oc, cfg=cfg)) + in_channels = oc + if dim == 1: + if classifier: + layers.append(nn.Linear(in_channels, out_channels[-1])) + else: + layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1]))) + else: + if classifier: + layers.append(nn.Conv1d(in_channels, out_channels[-1], 1)) + else: + layers.append(SharedMLP(in_channels, int(r * out_channels[-1]))) + return layers, out_channels[-1] if classifier else int(r * out_channels[-1]) + +def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, + input_dim=3, + embed_dim=64, use_att=False, force_att=0, + dropout=0.1, with_se=False, normalize=True, eps=0, has_temb=1, + width_multiplier=1, voxel_resolution_multiplier=1, verbose=True, + cfg={}): + """ + Returns: + in_channels: the last output channels of the sa blocks + """ + assert(len(cfg) > 0), cfg + r, vr = width_multiplier, voxel_resolution_multiplier + in_channels = extra_feature_channels + input_dim + + sa_layers, sa_in_channels = [], [] + c = 0 + num_centers = None + for conv_configs, sa_configs in sa_blocks: + k = 0 + sa_in_channels.append(in_channels) + sa_blocks = [] + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = ( (c+1) % 2 == 0 and use_att and p == 0 ) or (force_att and c > 0) + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial( + PVConv, kernel_size=3, + resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se, # with_se_relu=True, + normalize=normalize, eps=eps, verbose=verbose, cfg=cfg) + + if c == 0: + sa_blocks.append(block(in_channels, out_channels, cfg=cfg)) + elif k ==0: + sa_blocks.append(block(in_channels+embed_dim*has_temb, out_channels, cfg=cfg)) + in_channels = out_channels + k += 1 + extra_feature_channels = in_channels + if sa_configs is not None: + num_centers, radius, num_neighbors, out_channels = sa_configs + _out_channels = [] + for oc in out_channels: + if isinstance(oc, (list, tuple)): + _out_channels.append([int(r * _oc) for _oc in oc]) + else: + _out_channels.append(int(r * oc)) + out_channels = _out_channels + if num_centers is None: + block = PointNetAModule + else: + 
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius, + num_neighbors=num_neighbors) + sa_blocks.append(block(cfg=cfg, + in_channels=extra_feature_channels+(embed_dim*has_temb if k==0 else 0 ), + out_channels=out_channels, + include_coordinates=True)) + in_channels = extra_feature_channels = sa_blocks[-1].out_channels + c += 1 + + if len(sa_blocks) == 1: + sa_layers.append(sa_blocks[0]) + else: + sa_layers.append(nn.Sequential(*sa_blocks)) + + return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers + + +def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False, + dropout=0.1, has_temb=1, + with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1, + verbose=True, cfg={}): + assert(len(cfg) > 0), cfg + r, vr = width_multiplier, voxel_resolution_multiplier + + fp_layers = [] + c = 0 + + for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks): + fp_blocks = [] + out_channels = tuple(int(r * oc) for oc in fp_configs) + fp_blocks.append( + PointNetFPModule( + in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim*has_temb, + out_channels=out_channels, + cfg=cfg) + ) + in_channels = out_channels[-1] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0 + if voxel_resolution is None: + block = functools.partial(SharedMLP, cfg=cfg) + else: + block = functools.partial(PVConv, kernel_size=3, + resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se, # with_se_relu=True, + normalize=normalize, eps=eps, + verbose=verbose, + cfg=cfg) + + fp_blocks.append(block(in_channels, out_channels)) + in_channels = out_channels + if len(fp_blocks) == 1: + fp_layers.append(fp_blocks[0]) + else: + fp_layers.append(nn.Sequential(*fp_blocks)) + + c += 1 + + return fp_layers, in_channels + diff --git a/models/score_sde/resnet.py b/models/score_sde/resnet.py new file mode 100644 index 0000000..1f02356 --- /dev/null +++ b/models/score_sde/resnet.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
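For orientation, the two builders above consume a nested block spec: each `sa_blocks` entry pairs an optional `conv_configs = (out_channels, num_blocks, voxel_resolution)` (PVConv stages, or `SharedMLP` when the resolution is `None`) with an optional `sa_configs = (num_centers, radius, num_neighbors, mlp_channels)` (a `PointNetSAModule`, or `PointNetAModule` when `num_centers` is `None`). A minimal sketch, assuming the repo root is on `PYTHONPATH`, a working CUDA build of the PVCNN ops, and that the non-AdaGN builder in `models/pvcnn2.py` keeps the call signature used by `PointNetPlusEncoder` later in this patch:

```python
# Illustrative only: spec values copied from PointNetPlusEncoder below.
from models.pvcnn2 import create_pointnet2_sa_components

sa_blocks = [
    [[32, 2, 32], [1024, 0.1, 32, [32, 32]]],  # 2 PVConvs at voxel res 32 -> SA down to 1024 centers
    [[32, 1, 16], [256, 0.2, 32, [32, 64]]],   # 1 PVConv at voxel res 16 -> SA down to 256 centers
]
layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
    sa_blocks, extra_feature_channels=0, input_dim=3,
    embed_dim=0, use_att=True, with_se=True)
print(channels_sa_features)  # 64: the [32, 64] MLP of the last SA stage sets the width
```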
+""" implement the gloabl prior for LION +""" +import torch.nn as nn +from loguru import logger +import functools +import torch +from ..utils import init_temb_fun, mask_inactive_variables + +class SE(nn.Module): + def __init__(self, channel, reduction=8): + super().__init__() + self.fc = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, 1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, 1, bias=False), + nn.Sigmoid() + ) + + def forward(self, inputs): + return inputs * self.fc(inputs) + +class ResBlockSEClip(nn.Module): + """ + fixed the conv0 not used error in ResBlockSE + """ + def __init__(self, input_dim, output_dim): + super().__init__() + self.non_linearity = nn.ReLU(inplace=True) + self.input_dim = input_dim + self.output_dim = output_dim + self.conv1 = nn.Conv2d(input_dim*2, output_dim, 1, 1) + self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1) + in_ch = self.output_dim + self.SE = SE(in_ch) + def forward(self, x, t): + ## logger.info('x: {}, t: {}, input_dim={}', x.shape, t.shape, self.input_dim) + clip_feat = t[:, self.input_dim:].contiguous() + t = t[:,:self.input_dim].contiguous() + output = x + t + output = torch.cat([output, clip_feat], dim=1).contiguous() + output = self.conv1(output) + output = self.non_linearity(output) + output = self.conv2(output) + output = self.non_linearity(output) + output = self.SE(output) + shortcut = x + return shortcut + output + def __repr__(self): + return "ResBlockSEClip(%d, %d)"%(self.input_dim, self.output_dim) + + + +class ResBlockSEDrop(nn.Module): + """ + fixed the conv0 not used error in ResBlockSE + """ + + def __init__(self, input_dim, output_dim, dropout): + super().__init__() + self.non_linearity = nn.ReLU(inplace=True) + self.input_dim = input_dim + self.output_dim = output_dim + self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1) + self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1) + in_ch = self.output_dim + self.SE = SE(in_ch) + self.dropout = nn.Dropout(dropout) + self.dropout_ratio = dropout + + def forward(self, x, t): + output = x + t + output = self.conv1(output) + output = self.non_linearity(output) + output = self.dropout(output) + output = self.conv2(output) + output = self.non_linearity(output) + output = self.SE(output) + shortcut = x + return shortcut + output + + def __repr__(self): + return "ResBlockSE_withdropout(%d, %d, drop=%f)" % ( + self.input_dim, self.output_dim, self.dropout_ratio) + + +class ResBlock(nn.Module): + def __init__(self, input_dim, output_dim): + # resample=None, act=nn.ELU(), + # normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1): + super().__init__() + self.non_linearity = nn.ELU() + self.input_dim = input_dim + self.output_dim = output_dim + self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1) + self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1) + in_ch = self.output_dim + self.normalize1 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), + num_channels=in_ch, eps=1e-6) + self.normalize2 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), + num_channels=in_ch, eps=1e-6) + + def forward(self, x, t): + x = x + t + output = self.conv1(x) + output = self.normalize1(output) + output = self.non_linearity(output) + output = self.conv2(output) + output = self.normalize2(output) + output = self.non_linearity(output) + shortcut = x + return shortcut + output + + def __repr__(self): + return "ResBlock(%d, %d)" % (self.input_dim, self.output_dim) + + +class Prior(nn.Module): + building_block = ResBlock + + def __init__(self, args, 
num_input_channels, *oargs, **kwargs):
+        super().__init__()
+        # args: cfg.sde; oargs[0]: the global config
+        self.condition_input = kwargs.get('condition_input', False)
+        self.cfg = oargs[0]
+        self.clip_forge_enable = self.cfg.clipforge.enable
+
+        logger.info('[Build Resnet Prior] Has condition input: {}; clipforge {}; '
+                    'learn_mixing_logit={}, ', self.condition_input,
+                    self.clip_forge_enable, args.learn_mixing_logit)
+
+        self.act = act = nn.SiLU()
+        self.num_scales = args.num_scales_dae
+        self.num_input_channels = num_input_channels
+
+        self.nf = nf = args.num_channels_dae
+        num_cell_per_scale_dae = args.num_cell_per_scale_dae if 'num_cell_per_scale_dae' not in kwargs \
+            else kwargs['num_cell_per_scale_dae']
+
+        # take the clip feature as input
+        if self.clip_forge_enable:
+            self.clip_feat_mapping = nn.Conv1d(self.cfg.clipforge.feat_dim, self.nf, 1)
+
+        # mixed prediction #
+        self.mixed_prediction = args.mixed_prediction  # enables mixed prediction
+        if self.mixed_prediction:
+            logger.info('init-mixing_logit = {}, after sigmoid = {}',
+                        args.mixing_logit_init, torch.sigmoid(torch.tensor(args.mixing_logit_init)))
+            assert args.mixing_logit_init, 'a nonzero mixing_logit_init is required when mixed prediction is on'
+            init = args.mixing_logit_init * torch.ones(size=[1, num_input_channels, 1, 1])
+            self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
+            self.is_active = None
+        else:  # no mixing_logit
+            self.mixing_logit = None
+            self.is_active = None
+
+        self.embedding_dim = args.embedding_dim
+        self.embedding_dim_mult = 4
+        self.temb_fun = init_temb_fun(args.embedding_type, args.embedding_scale, args.embedding_dim)
+        logger.info('[temb_fun] embedding_type={}, embedding_scale={}, embedding_dim={}',
+                    args.embedding_type, args.embedding_scale, args.embedding_dim)
+        modules = []
+        modules.append(nn.Conv2d(self.embedding_dim, self.embedding_dim * 4, 1, 1))
+        modules.append(nn.Conv2d(self.embedding_dim * 4, nf, 1, 1))
+        self.temb_layer = nn.Sequential(*modules)
+
+        modules = []
+        input_channels = num_input_channels
+        self.input_layer = nn.Conv2d(input_channels, nf, 1, 1)
+        in_ch = nf
+        for i_block in range(num_cell_per_scale_dae):
+            modules.append(self.building_block(nf, nf))
+        self.output_layer = nn.Conv2d(nf, input_channels, 1, 1)
+        self.all_modules = nn.ModuleList(modules)
+
+    def forward(self, x, t, **kwargs):
+        # timestep/noise-level embedding; only for continuous training
+        if t.dim() == 0:
+            t = t.expand(1)
+        temb = self.temb_fun(t)[:, :, None, None]  # make it 4d
+        temb = self.temb_layer(temb)
+
+        if self.clip_forge_enable:
+            clip_feat = kwargs['clip_feat']
+            clip_feat = self.clip_feat_mapping(clip_feat[:, :, None])[:, :, :, None]  # B,D -> B,D,1 -> B,D,1,1
+            if temb.shape[0] == 1 and temb.shape[0] < clip_feat.shape[0]:
+                temb = temb.expand(clip_feat.shape[0], -1, -1, -1)
+            temb = torch.cat([temb, clip_feat], dim=1)  # append the clip feature to temb
+        # mask out
inactive variables + if self.mixed_prediction and self.is_active is not None: + x = mask_inactive_variables(x, self.is_active) + x = self.input_layer(x) + for layer in self.all_modules: + enc_input = x + x = layer(enc_input, temb) + + h = self.output_layer(x) + return h + + +class PriorSEDrop(Prior): + def __init__(self, *args, **kwargs): + self.building_block = functools.partial(ResBlockSEDrop, dropout=args[0].dropout) + super().__init__(*args, **kwargs) + +class PriorSEClip(Prior): + building_block = ResBlockSEClip + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + diff --git a/models/shapelatent_modules.py b/models/shapelatent_modules.py new file mode 100644 index 0000000..46a83e9 --- /dev/null +++ b/models/shapelatent_modules.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch.nn as nn +from loguru import logger +from .pvcnn2 import create_pointnet2_sa_components +# implement the global encoder for VAE model + +class PointNetPlusEncoder(nn.Module): + sa_blocks = [ + [[32, 2, 32], [1024, 0.1, 32, [32, 32]]], + [[32, 1, 16], [256, 0.2, 32, [32, 64]]] + ] + force_att = 0 # add attention to all layers + def __init__(self, zdim, input_dim, extra_feature_channels=0, args={}): + super().__init__() + sa_blocks = self.sa_blocks + layers, sa_in_channels, channels_sa_features, _ = \ + create_pointnet2_sa_components(sa_blocks, + extra_feature_channels, input_dim=input_dim, + embed_dim=0, force_att=self.force_att, + use_att=True, with_se=True) + self.mlp = nn.Linear(channels_sa_features, zdim*2) + self.zdim = zdim + logger.info('[Encoder] zdim={}, out_sigma={}; force_att: {}', zdim, True, self.force_att) + self.layers = nn.ModuleList(layers) + self.voxel_dim = [n[1][-1][-1] for n in self.sa_blocks] + + def forward(self, x): + """ + Args: + x: B,N,3 + Returns: + mu, sigma: B,D + """ + output = {} + x = x.transpose(1, 2) # B,3,N + xyz = x ## x[:,:3,:] + features = x + for layer_id, layer in enumerate(self.layers): + features, xyz, _ = layer( (features, xyz, None) ) + # features: B,D,N; xyz: B,3,N + + features = features.max(-1)[0] + features = self.mlp(features) + mu_1d, sigma_1d = features[:, :self.zdim], features[:, self.zdim:] + output.update({'mu_1d': mu_1d, 'sigma_1d': sigma_1d}) + return output + + diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000..bd500db --- /dev/null +++ b/models/utils.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
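A shape-level usage sketch for the `PointNetPlusEncoder` defined above (a sketch only; it assumes the repo root is on `PYTHONPATH` and a CUDA device, since the underlying PVCNN kernels are GPU-only):

```python
import torch
from models.shapelatent_modules import PointNetPlusEncoder

enc = PointNetPlusEncoder(zdim=128, input_dim=3).cuda()
x = torch.randn(4, 2048, 3).cuda()                 # B,N,3 input point clouds
out = enc(x)                                       # dict with posterior parameters
print(out['mu_1d'].shape, out['sigma_1d'].shape)   # torch.Size([4, 128]) each
```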
+import torch +import math +import torch.nn as nn + +def mask_inactive_variables(x, is_active): + x = x * is_active + return x + +class PositionalEmbedding(nn.Module): + def __init__(self, embedding_dim, scale): + super(PositionalEmbedding, self).__init__() + self.embedding_dim = embedding_dim + self.scale = scale + + def forward(self, timesteps): + assert len(timesteps.shape) == 1 + timesteps = timesteps * self.scale + half_dim = self.embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + return emb + + +class RandomFourierEmbedding(nn.Module): + def __init__(self, embedding_dim, scale): + super(RandomFourierEmbedding, self).__init__() + self.w = nn.Parameter(torch.randn(size=(1, embedding_dim // 2)) * scale, requires_grad=False) + + def forward(self, timesteps): + emb = torch.mm(timesteps[:, None], self.w * 2 * 3.14159265359) + return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + + +def init_temb_fun(embedding_type, embedding_scale, embedding_dim): + if embedding_type == 'positional': + temb_fun = PositionalEmbedding(embedding_dim, embedding_scale) + elif embedding_type == 'fourier': + temb_fun = RandomFourierEmbedding(embedding_dim, embedding_scale) + else: + raise NotImplementedError + + return temb_fun diff --git a/models/vae_adain.py b/models/vae_adain.py new file mode 100644 index 0000000..325b8e2 --- /dev/null +++ b/models/vae_adain.py @@ -0,0 +1,339 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
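The embeddings in `models/utils.py` above follow the standard sinusoidal scheme: for embedding dimension D, frequencies are geometrically spaced via exp(-log(10000) * k / (D/2 - 1)). A quick sketch of how `init_temb_fun` is called (values illustrative; runs on CPU):

```python
import torch
from models.utils import init_temb_fun

# 'positional': fixed sin/cos frequencies; 'fourier': random (frozen) projections.
temb_fun = init_temb_fun('positional', embedding_scale=1.0, embedding_dim=128)
t = torch.tensor([0.1, 0.5, 0.9])  # 1-D tensor of (continuous) timesteps
emb = temb_fun(t)
print(emb.shape)  # torch.Size([3, 128]): first 64 dims are sin, last 64 are cos
```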
+import torch +import numpy as np +from loguru import logger +import importlib +import torch.nn as nn +from .distributions import Normal +from utils.model_helper import import_model +from utils.model_helper import loss_fn +from utils import utils as helper + +class Model(nn.Module): + def __init__(self, args): + super().__init__() + self.num_total_iter = 0 + self.args = args + self.input_dim = args.ddpm.input_dim + latent_dim = args.shapelatent.latent_dim + self.latent_dim = latent_dim + self.kl_weight = args.shapelatent.kl_weight + + self.num_points = args.data.tr_max_sample_points + # ---- global ---- # + # build encoder + self.style_encoder = import_model(args.latent_pts.style_encoder)( + zdim=args.latent_pts.style_dim, + input_dim=self.input_dim, + args=args) + if len(args.latent_pts.style_mlp): + self.style_mlp = import_model(args.latent_pts.style_mlp)(args) + else: + self.style_mlp = None + + self.encoder = import_model(args.shapelatent.encoder_type)( + zdim=latent_dim, + input_dim=self.input_dim, + args=args) + + # build decoder + self.decoder = import_model(args.shapelatent.decoder_type)( + context_dim=latent_dim, + point_dim=args.ddpm.input_dim, + args=args) + logger.info('[Build Model] style_encoder: {}, encoder: {}, decoder: {}', + args.latent_pts.style_encoder, + args.shapelatent.encoder_type, + args.shapelatent.decoder_type) + + @torch.no_grad() + def encode(self, x, class_label=None): + batch_size, _, point_dim = x.size() + assert(x.shape[2] == self.input_dim), f'expect input in ' \ + f'[B,Npoint,PointDim={self.input_dim}], get: {x.shape}' + x_0_target = x + latent_list = [] + all_eps = [] + all_log_q = [] + if self.args.data.cond_on_cat: + assert(class_label is not None), f'require class label input for cond on cat' + cls_emb = self.class_embedding(class_label) + enc_input = x, cls_emb + else: + enc_input = x + + # ---- global style encoder ---- # + z = self.style_encoder(enc_input) + z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma + dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F) + z_global = dist.sample()[0] + all_eps.append(z_global) + all_log_q.append(dist.log_p(z_global)) + latent_list.append( [z_global, z_mu, z_sigma] ) + + # ---- original encoder ---- # + style = z_global # torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global + style = self.style_mlp(style) if self.style_mlp is not None else style + z = self.encoder([x, style]) + z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] + z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset + dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F) + z_local = dist.sample()[0] + all_eps.append(z_local) + all_log_q.append(dist.log_p(z_local)) + latent_list.append( [z_local, z_mu, z_sigma] ) + all_eps = self.compose_eps(all_eps) + if self.args.data.cond_on_cat: + return all_eps, all_log_q, latent_list, cls_emb + else: + return all_eps, all_log_q, latent_list + + def compose_eps(self, all_eps): + return torch.cat(all_eps, dim=1) # style: [B,D1], latent pts: [B,ND2] + + def decompose_eps(self, all_eps): + eps_style = all_eps[:,:self.args.latent_pts.style_dim] + eps_local = all_eps[:,self.args.latent_pts.style_dim:] + return [eps_style, eps_local] + + def encode_global(self, x, class_label=None): + + batch_size, N, point_dim = x.size() + if self.args.data.cond_on_cat: + assert(class_label is not None), f'require class label input for cond on cat' + cls_emb = self.class_embedding(class_label) + enc_input = x, cls_emb + else: + enc_input = x + + z = self.style_encoder(enc_input) + z_mu, z_sigma = 
z['mu_1d'], z['sigma_1d'] # log_sigma + dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F) + return dist + + def global2style(self, style): ##, cls_emb=None): + Ndim = len(style.shape) + if Ndim == 4: + style = style.squeeze(-1).squeeze(-1) + style = self.style_mlp(style) if self.style_mlp is not None else style + if Ndim == 4: + style = style.unsqueeze(-1).unsqueeze(-1) + return style + + def encode_local(self, x, style): + # ---- original encoder ---- # + z = self.encoder([x, style]) + z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma + z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset + dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F) + return dist + + def recont(self, x, target=None, class_label=None, cls_emb=None): + batch_size, N, point_dim = x.size() + assert(x.shape[2] == self.input_dim), f'expect input in ' \ + f'[B,Npoint,PointDim={self.input_dim}], get: {x.shape}' + x_0_target = x if target is None else target + latent_list = [] + all_eps = [] + all_log_q = [] + + # ---- global style encoder ---- # + if self.args.data.cond_on_cat: + if class_label is not None: + assert(class_label is not None) + cls_emb = self.class_embedding(class_label) + else: + assert(cls_emb is not None) + + enc_input = x, cls_emb + else: + enc_input = x + z = self.style_encoder(enc_input) + z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma + dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F) + + z_global = dist.sample()[0] + all_eps.append(z_global) + all_log_q.append(dist.log_p(z_global)) + latent_list.append( [z_global, z_mu, z_sigma] ) + + # ---- original encoder ---- # + style = torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global + style = self.style_mlp(style) if self.style_mlp is not None else style + z = self.encoder([x, style]) + z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma + z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset + dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F) + z_local = dist.sample()[0] + all_eps.append(z_local) + all_log_q.append(dist.log_p(z_local)) + latent_list.append( [z_local, z_mu, z_sigma] ) + + # ---- decoder ---- # + x_0_pred = self.decoder(None, beta=None, context=z_local, style=style) # (B,ncenter,3) + + make_4d = lambda x: x.unsqueeze(-1).unsqueeze(-1) if len(x.shape) == 2 else x.unsqueeze(-1) + all_eps = [make_4d(e) for e in all_eps] + all_log_q = [make_4d(e) for e in all_log_q] + + output = { + 'all_eps': all_eps, + 'all_log_q': all_log_q, + 'latent_list': latent_list, + 'x_0_pred':x_0_pred, + 'x_0_target': x_0_target, + 'x_t': torch.zeros_like(x_0_target), + 't': torch.zeros(batch_size), + 'x_0': x_0_target + } + output['hist/global_var'] = latent_list[0][2].exp() + + if 'LatentPoint' in self.args.shapelatent.decoder_type: + latent_shape = [batch_size, -1, self.latent_dim + self.input_dim] + if 'Hir' in self.args.shapelatent.decoder_type: + latent_pts = z_local[:,:-self.args.latent_pts.latent_dim_ext[0]].view(*latent_shape)[:,:,:3].contiguous().clone() + else: + latent_pts = z_local.view(*latent_shape)[:,:,:self.input_dim].contiguous().clone() + + output['vis/latent_pts'] = latent_pts.detach().cpu().view(batch_size, + -1, self.input_dim) # B,N,3 + output['final_pred'] = output['x_0_pred'] + return output + + def get_loss(self, x, writer=None, it=None, ## weight_loss_1=1, + noisy_input=None, class_label=None, **kwargs): + """ + shapelatent z ~ q(z|x_0) + and x_t ~ q(x_t|x_0, t), t ~ Uniform(T) + forward and get x_{t-1} ~ p(x_{t-1} | x_t, z) + Args: + x: Input point clouds, (B, N, d). 
+ """ + ## kl_weight = self.kl_weight + if self.args.trainer.anneal_kl and self.num_total_iter > 0: + global_step = it + kl_weight = helper.kl_coeff(step=global_step, + total_step=self.args.sde.kl_anneal_portion_vada * self.num_total_iter, + constant_step=self.args.sde.kl_const_portion_vada * self.num_total_iter, + min_kl_coeff=self.args.sde.kl_const_coeff_vada, + max_kl_coeff=self.args.sde.kl_max_coeff_vada) + else: + kl_weight = self.kl_weight + + batch_size = x.shape[0] + # CHECKDIM(x, 2, self.input_dim) + assert(x.shape[2] == self.input_dim) + + inputs = noisy_input if noisy_input is not None else x + output = self.recont(inputs, target=x, class_label=class_label) + + x_0_pred, x_0_target = output['x_0_pred'], output['x_0_target'] + loss_0 = loss_fn(x_0_pred, x_0_target, self.args.ddpm.loss_type, + self.input_dim, batch_size).mean() + rec_loss = loss_0 + output['print/loss_0'] = loss_0 + output['rec_loss'] = rec_loss + + # Loss + ## z_global, z_sigma, z_mu = output['z_global'], output['z_sigma'], output['z_mu'] + kl_term_list = [] + weighted_kl_terms = [] + for pairs_id, pairs in enumerate(output['latent_list']): + cz, cmu, csigma = pairs + log_sigma = csigma + kl_term_close = (0.5*log_sigma.exp()**2 + + 0.5*cmu**2 - log_sigma - 0.5).view( + batch_size, -1) + if 'LatentPoint' in self.args.shapelatent.decoder_type and 'Hir' not in self.args.shapelatent.decoder_type: + if pairs_id == 1: + latent_shape = [batch_size, -1, self.latent_dim + self.input_dim] + kl_pt = kl_term_close.view(*latent_shape)[:,:,:self.input_dim] + kl_feat = kl_term_close.view(*latent_shape)[:,:,self.input_dim:] + weighted_kl_terms.append(kl_pt.sum(2).sum(1) * self.args.latent_pts.weight_kl_pt) + weighted_kl_terms.append(kl_feat.sum(2).sum(1) * self.args.latent_pts.weight_kl_feat) + + output['print/kl_pt%d'%pairs_id] = kl_pt.sum(2).sum(1) + output['print/kl_feat%d'%pairs_id] = kl_feat.sum(2).sum(1) + + output['print/z_var_pt%d'%pairs_id] = (log_sigma.view(*latent_shape)[:,:,:self.input_dim] + ).exp()**2 + output['print/z_var_feat%d'%pairs_id] = (log_sigma.view(*latent_shape)[:,:,self.input_dim:] + ).exp()**2 + output['print/z_mean_feat%d'%pairs_id] = cmu.view(*latent_shape)[:,:,self.input_dim:].mean() + elif pairs_id == 0: + kl_style = kl_term_close + weighted_kl_terms.append(kl_style.sum(-1) * self.args.latent_pts.weight_kl_glb) + + output['print/kl_glb%d'%pairs_id] = kl_style.sum(-1) + output['print/z_var_glb%d'%pairs_id] = (log_sigma).exp()**2 + + kl_term_close = kl_term_close.sum(-1) + kl_term_list.append(kl_term_close) + output['print/kl_%d'%pairs_id] = kl_term_close + output['print/z_mean_%d'%pairs_id] = cmu.mean() + output['print/z_mag_%d'%pairs_id] = cmu.abs().max() + # logger.info('log_sigma: {}, mean: {}', log_sigma.shape, (log_sigma.exp()**2).mean()) + output['print/z_var_%d'%pairs_id] = (log_sigma).exp()**2 + output['print/z_logsigma_%d'%pairs_id] = log_sigma + output['print/kl_weight'] = kl_weight + + + loss_recons = rec_loss + if len(weighted_kl_terms) > 0: + kl = kl_weight * sum(weighted_kl_terms) + else: + kl = kl_weight * sum(kl_term_list) + loss = kl + loss_recons * self.args.weight_recont + output['msg/kl'] = kl + output['msg/rec'] = loss_recons + output['loss'] = loss + return output + + def pz(self, w): + return w + + def sample(self, num_samples=10, temp=None, decomposed_eps=[], + enable_autocast=False, device_str='cuda', cls_emb=None): + """ currently not support the samples of local level + Return: + model_output: [B,N,D] + """ + batch_size = num_samples + center_emd = None + if 'LatentPoint' 
in self.args.shapelatent.decoder_type:
+            # latent point model: flatten the N x (latent_dim + input_dim) latent points to B, N*D
+            latent_shape = (num_samples, self.num_points*(self.latent_dim+self.input_dim))
+            style_latent_shape = (num_samples, self.args.latent_pts.style_dim)
+        else:
+            raise NotImplementedError
+
+        if len(decomposed_eps) == 0:
+            z_local = torch.zeros(*latent_shape).to(
+                torch.device(device_str)).normal_()
+            z_global = torch.zeros(*style_latent_shape).to(
+                torch.device(device_str)).normal_()
+        else:
+            z_global = decomposed_eps[0]
+            z_local = decomposed_eps[1]
+
+        z_local = z_local.view(*latent_shape)
+        z_global = z_global.view(*style_latent_shape)
+
+        style = self.style_mlp(z_global) if self.style_mlp is not None else z_global
+        x_0_pred = self.decoder(None, beta=None,
+                                context=z_local, style=style)  # (B, ncenter, 3)
+        # x_0_pred: (batch_size, self.num_points, 3 or 6)
+        return x_0_pred
+
+    def latent_shape(self):
+        return [
+            [self.args.latent_pts.style_dim, 1, 1],
+            [self.num_points*(self.latent_dim+self.input_dim), 1, 1]
+        ]
diff --git a/script/compute_score.py b/script/compute_score.py
new file mode 100644
index 0000000..ff0498a
--- /dev/null
+++ b/script/compute_score.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+import sys
+sys.path.append('.')
+from utils.eval_helper import compute_score
+# samples = sys.argv[1]
+# ref = sys.argv[2]
+
+samples = './lion_ckpt/unconditional/car/samples.pt'
+ref = './datasets/test_data/ref_val_car.pt'
+compute_score(samples, ref_name=ref)
+"""
+will get:
+[Test] MinMatDis | CD 0.000913 | EMD 0.007523
+[Test] Coverage | CD 0.500000 | EMD 0.565341
+[Test] 1NN-Accur | CD 0.534091 | EMD 0.511364
+[Test] JsnShnDis | 0.009229
+"""
+
+samples = './lion_ckpt/unconditional/chair/samples.pt'
+ref = './datasets/test_data/ref_val_chair.pt'
+compute_score(samples, ref_name=ref)
+"""
+[Test] MinMatDis | CD 0.002643 | EMD 0.015516
+[Test] Coverage | CD 0.489426 | EMD 0.521148
+[Test] 1NN-Accur | CD 0.537009 | EMD 0.523414
+[Test] JsnShnDis | 0.013535
+"""
+
+samples = './lion_ckpt/unconditional/airplane/samples.pt'
+ref = './datasets/test_data/ref_val_airplane.pt'
+compute_score(samples, ref_name=ref)
+"""
+[Test] MinMatDis | CD 0.000221 | EMD 0.003706
+[Test] Coverage | CD 0.471605 | EMD 0.496296
+[Test] 1NN-Accur | CD 0.674074 | EMD 0.612346
+[Test] JsnShnDis | 0.060703
+"""
diff --git a/third_party/ChamferDistancePytorch/.gitignore b/third_party/ChamferDistancePytorch/.gitignore
new file mode 100644
index 0000000..35b4be2
--- /dev/null
+++ b/third_party/ChamferDistancePytorch/.gitignore
@@ -0,0 +1,3 @@
+*__pycache__*
+/tmp
+tmp/*
diff --git a/third_party/ChamferDistancePytorch/LICENSE b/third_party/ChamferDistancePytorch/LICENSE
new file mode 100644
index 0000000..794e2df
--- /dev/null
+++ b/third_party/ChamferDistancePytorch/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 ThibaultGROUEIX
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/ChamferDistancePytorch/README.md b/third_party/ChamferDistancePytorch/README.md new file mode 100755 index 0000000..f31b76f --- /dev/null +++ b/third_party/ChamferDistancePytorch/README.md @@ -0,0 +1,104 @@ +* adapted from https://github.com/ThibaultGROUEIX/ChamferDistancePytorch + +---------------------------------- +# Pytorch Chamfer Distance. + +Include a **CUDA** version, and a **PYTHON** version with pytorch standard operations. +NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt thresholds accordingly. + +- [x] F - Score + + + +### CUDA VERSION + +- [x] JIT compilation +- [x] Supports multi-gpu +- [x] 2D point clouds. +- [x] 3D point clouds. +- [x] 5D point clouds. +- [x] Contiguous() safe. + + + +### Python Version + +- [x] Supports any dimension + + + +### Usage + +```python +import torch, chamfer3D.dist_chamfer_3D, fscore +chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist() +points1 = torch.rand(32, 1000, 3).cuda() +points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda() +dist1, dist2, idx1, idx2 = chamLoss(points1, points2) +f_score, precision, recall = fscore.fscore(dist1, dist2) +``` + + + +### Add it to your project as a submodule + +```shell +git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch +``` + + + +### Benchmark: [forward + backward] pass +- [x] CUDA 10.1, NVIDIA 435, Pytorch 1.4 +- [x] p1 : 32 x 2000 x dim +- [x] p2 : 32 x 1000 x dim + +| *Timing (sec * 1000)* | 2D | 3D | 5D | +| ---------- | -------- | ------- | ------- | +| **Cuda Compiled** | **1.2** | 1.4 |1.8 | +| **Cuda JIT** | 1.3 | **1.4** |**1.5** | +| **Python** | 37 | 37 | 37 | + + +| *Memory (MB)* | 2D | 3D | 5D | +| ---------- | -------- | ------- | ------- | +| **Cuda Compiled** | 529 | 529 | 549 | +| **Cuda JIT** | **520** | **529** |**549** | +| **Python** | 2495 | 2495 | 2495 | + + + +### What is the chamfer distance ? + +[Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep Learning + + + +### Aknowledgment + +Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu). + +JIT cool trick from [Christian Diller](https://github.com/chrdiller) + +### Troubleshoot + +- `Undefined symbol: Zxxxxxxxxxxxxxxxxx `: + +--> Fix: Make sure to `import torch` before you `import chamfer`. 
+--> Use pytorch.version >= 1.1.0 + +- [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167) + +```shell +wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +sudo unzip ninja-linux.zip -d /usr/local/bin/ +sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force +``` + + + + + +#### TODO: + +* Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions diff --git a/third_party/ChamferDistancePytorch/chamfer2D/chamfer2D.cu b/third_party/ChamferDistancePytorch/chamfer2D/chamfer2D.cu new file mode 100755 index 0000000..567dd1a --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer2D/chamfer2D.cu @@ -0,0 +1,182 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*2]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} + diff --git a/third_party/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp b/third_party/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp new file mode 100755 index 0000000..67574e2 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, 
at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/third_party/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py b/third_party/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py new file mode 100644 index 0000000..b013642 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py @@ -0,0 +1,80 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os +chamfer_found = importlib.find_loader("chamfer_2D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 2D") + cur_path = os.path.dirname(os.path.abspath(__file__)) + build_path = cur_path.replace('chamfer2D', 'tmp') + os.makedirs(build_path, exist_ok=True) + + from torch.utils.cpp_extension import load + chamfer_2D = load(name="chamfer_2D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]), + ], build_directory=build_path) + print("Loaded JIT 2D CUDA chamfer distance") + +else: + import chamfer_2D + print("Loaded compiled 2D CUDA chamfer distance") + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_2DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, dim = xyz1.size() + assert dim==2, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + _, m, dim = xyz2.size() + assert dim==2, "Wrong last dimension for the chamfer distance 's input! 
Check with .size()" + device = xyz1.device + + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_2D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_2DDist(nn.Module): + def __init__(self): + super(chamfer_2DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_2DFunction.apply(input1, input2) diff --git a/third_party/ChamferDistancePytorch/chamfer2D/setup.py b/third_party/ChamferDistancePytorch/chamfer2D/setup.py new file mode 100755 index 0000000..11d0123 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer2D/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_2D', + ext_modules=[ + CUDAExtension('chamfer_2D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']), + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/third_party/ChamferDistancePytorch/chamfer3D/chamfer3D.cu b/third_party/ChamferDistancePytorch/chamfer3D/chamfer3D.cu new file mode 100755 index 0000000..d5b886d --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer3D/chamfer3D.cu @@ -0,0 +1,196 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * 
grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} + diff --git a/third_party/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp b/third_party/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp new file mode 100755 index 0000000..67574e2 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/third_party/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py b/third_party/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py new file mode 100644 index 0000000..89a248c --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py @@ -0,0 +1,133 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os +from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd +cur_path = os.path.dirname(os.path.abspath(__file__)) +build_path = cur_path.replace('chamfer3D', 'tmp') +os.makedirs(build_path, exist_ok=True) + +from torch.utils.cpp_extension import load +chamfer_3D = load(name="chamfer_3D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), + ], build_directory=build_path) + +#chamfer_found = importlib.find_loader("chamfer_3D") is not None +#if not chamfer_found: +# ## Cool trick from https://github.com/chrdiller +# print("Jitting Chamfer 3D") +# cur_path = os.path.dirname(os.path.abspath(__file__)) +# build_path = cur_path.replace('chamfer3D', 'tmp') +# os.makedirs(build_path, exist_ok=True) +# +# from torch.utils.cpp_extension import load +# chamfer_3D = load(name="chamfer_3D", +# sources=[ +# "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), +# "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), +# ], build_directory=build_path) +# print("Loaded JIT 3D CUDA chamfer distance") +# +#else: +# import chamfer_3D +# print("Loaded compiled 3D CUDA chamfer distance") + + +# Chamfer's distance module @thibaultgroueix +# 
GPU tensors only +class chamfer_3DFunction(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, xyz1, xyz2): + batchsize, n, dim = xyz1.size() + assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + _, m, dim = xyz2.size() + assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + device = xyz1.device + + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + @custom_bwd + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_3D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_3DDist(nn.Module): + def __init__(self): + super(chamfer_3DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_3DFunction.apply(input1, input2) + + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_3DFunction_noGrad(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, dim = xyz1.size() + assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + _, m, dim = xyz2.size() + assert dim==3, "Wrong last dimension for the chamfer distance 's input! 
Check with .size()" + device = xyz1.device + + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + return dist1, dist2, idx1, idx2 + +class chamfer_3DDist_nograd(nn.Module): + def __init__(self): + super(chamfer_3DDist_nograd, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_3DFunction_noGrad.apply(input1, input2) diff --git a/third_party/ChamferDistancePytorch/chamfer3D/setup.py b/third_party/ChamferDistancePytorch/chamfer3D/setup.py new file mode 100755 index 0000000..9a23aad --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer3D/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_3D', + ext_modules=[ + CUDAExtension('chamfer_3D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/third_party/ChamferDistancePytorch/chamfer5D/chamfer5D.cu b/third_party/ChamferDistancePytorch/chamfer5D/chamfer5D.cu new file mode 100755 index 0000000..650e889 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer5D/chamfer5D.cu @@ -0,0 +1,223 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=2048; + __shared__ float buf[batch*5]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} diff --git 
a/third_party/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp b/third_party/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp new file mode 100755 index 0000000..67574e2 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/third_party/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py b/third_party/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py new file mode 100644 index 0000000..9cf749d --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py @@ -0,0 +1,82 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os + +chamfer_found = importlib.find_loader("chamfer_5D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 5D") + cur_path = os.path.dirname(os.path.abspath(__file__)) + build_path = cur_path.replace('chamfer5D', 'tmp') + os.makedirs(build_path, exist_ok=True) + + from torch.utils.cpp_extension import load + chamfer_5D = load(name="chamfer_5D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]), + ], build_directory=build_path) + print("Loaded JIT 5D CUDA chamfer distance") + +else: + import chamfer_5D + print("Loaded compiled 5D CUDA chamfer distance") + + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_5DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, dim = xyz1.size() + assert dim==5, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + _, m, dim = xyz2.size() + assert dim==5, "Wrong last dimension for the chamfer distance 's input! 
Check with .size()" + device = xyz1.device + + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_5D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_5D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_5DDist(nn.Module): + def __init__(self): + super(chamfer_5DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_5DFunction.apply(input1, input2) diff --git a/third_party/ChamferDistancePytorch/chamfer5D/setup.py b/third_party/ChamferDistancePytorch/chamfer5D/setup.py new file mode 100755 index 0000000..2429235 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer5D/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_5D', + ext_modules=[ + CUDAExtension('chamfer_5D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']), + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/third_party/ChamferDistancePytorch/chamfer6D/chamfer6D.cu b/third_party/ChamferDistancePytorch/chamfer6D/chamfer6D.cu new file mode 100755 index 0000000..bb51530 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer6D/chamfer6D.cu @@ -0,0 +1,237 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=2048; + __shared__ float buf[batch*6]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * 
grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} diff --git a/third_party/ChamferDistancePytorch/chamfer6D/chamfer_cuda.cpp b/third_party/ChamferDistancePytorch/chamfer6D/chamfer_cuda.cpp new file mode 100755 index 0000000..67574e2 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer6D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/third_party/ChamferDistancePytorch/chamfer6D/dist_chamfer_6D.py b/third_party/ChamferDistancePytorch/chamfer6D/dist_chamfer_6D.py new file mode 100755 index 0000000..0f073f8 --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer6D/dist_chamfer_6D.py @@ -0,0 +1,82 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os + +chamfer_found = importlib.find_loader("chamfer_6D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 6D") + cur_path = os.path.dirname(os.path.abspath(__file__)) + build_path = cur_path.replace('chamfer6D', 'tmp') + os.makedirs(build_path, exist_ok=True) + + from torch.utils.cpp_extension import load + chamfer_6D = load(name="chamfer_6D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer6D.cu"]), + ], build_directory=build_path) + print("Loaded JIT 6D CUDA chamfer distance") + +else: + import chamfer_6D + print("Loaded compiled 6D CUDA chamfer distance") + + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_6DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, dim = xyz1.size() + assert dim==6, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + _, m, dim = xyz2.size() + assert dim==6, "Wrong last dimension for the chamfer distance 's input! 
Check with .size()" + device = xyz1.device + + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_6D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_6D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_6DDist(nn.Module): + def __init__(self): + super(chamfer_6DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_6DFunction.apply(input1, input2) diff --git a/third_party/ChamferDistancePytorch/chamfer6D/setup.py b/third_party/ChamferDistancePytorch/chamfer6D/setup.py new file mode 100755 index 0000000..4b9044c --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer6D/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_6D', + ext_modules=[ + CUDAExtension('chamfer_6D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer6D.cu']), + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/third_party/ChamferDistancePytorch/chamfer_python.py b/third_party/ChamferDistancePytorch/chamfer_python.py new file mode 100644 index 0000000..9c3bc1d --- /dev/null +++ b/third_party/ChamferDistancePytorch/chamfer_python.py @@ -0,0 +1,44 @@ +import torch + + +def pairwise_dist(x, y): + xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) + rx = xx.diag().unsqueeze(0).expand_as(xx) + ry = yy.diag().unsqueeze(0).expand_as(yy) + P = rx.t() + ry - 2 * zz + return P + + +def NN_loss(x, y, dim=0): + dist = pairwise_dist(x, y) + values, indices = dist.min(dim=dim) + return values.mean() + + +def batched_pairwise_dist(a, b): + x, y = a.double(), b.double() + bs, num_points_x, points_dim = x.size() + bs, num_points_y, points_dim = y.size() + + xx = torch.pow(x, 2).sum(2) + yy = torch.pow(y, 2).sum(2) + zz = torch.bmm(x, y.transpose(2, 1)) + rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx + ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy + P = rx.transpose(2, 1) + ry - 2 * zz + return P + +def distChamfer(a, b): + """ + :param a: Pointclouds Batch x nul_points x dim + :param b: Pointclouds Batch x nul_points x dim + :return: + -closest point on b of points from a + -closest point on a of points from b + -idx of closest point on b of points from a + -idx of closest point on a of points from b + Works for pointcloud of any dimension + """ + P = batched_pairwise_dist(a, b) + return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int() + diff --git 
a/third_party/ChamferDistancePytorch/fscore.py b/third_party/ChamferDistancePytorch/fscore.py
new file mode 100644
index 0000000..265378b
--- /dev/null
+++ b/third_party/ChamferDistancePytorch/fscore.py
@@ -0,0 +1,17 @@
+import torch
+
+def fscore(dist1, dist2, threshold=0.001):
+    """
+    Calculates the F-score between two point clouds with the corresponding threshold value.
+    :param dist1: Batch, N-Points
+    :param dist2: Batch, N-Points
+    :param threshold: float
+    :return: fscore, precision, recall
+    """
+    # NB : In this repo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly.
+    precision_1 = torch.mean((dist1 < threshold).float(), dim=1)
+    precision_2 = torch.mean((dist2 < threshold).float(), dim=1)
+    fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
+    fscore[torch.isnan(fscore)] = 0
+    return fscore, precision_1, precision_2
+
diff --git a/third_party/ChamferDistancePytorch/unit_test.py b/third_party/ChamferDistancePytorch/unit_test.py
new file mode 100644
index 0000000..13af6a3
--- /dev/null
+++ b/third_party/ChamferDistancePytorch/unit_test.py
@@ -0,0 +1,69 @@
+import torch, time
+import chamfer2D.dist_chamfer_2D
+import chamfer3D.dist_chamfer_3D
+import chamfer5D.dist_chamfer_5D
+import chamfer_python
+
+cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
+cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
+cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
+
+from torch.autograd import Variable
+from fscore import fscore
+
+def test_chamfer(distChamfer, dim):
+    points1 = torch.rand(4, 100, dim).cuda()
+    points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
+    dist1, dist2, idx1, idx2 = distChamfer(points1, points2)
+
+    loss = torch.sum(dist1)
+    loss.backward()
+
+    mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2)
+    d1 = (dist1 - mydist1) ** 2
+    d2 = (dist2 - mydist2) ** 2
+    assert (
+        torch.mean(d1) + torch.mean(d2) < 0.00000001
+    ), "chamfer cuda and chamfer normal are not giving the same results"
+
+    xd1 = idx1 - myidx1
+    xd2 = idx2 - myidx2
+    assert (
+        torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
+    ), "chamfer cuda and chamfer normal are not giving the same results"
+    print("fscore :", fscore(dist1, dist2))
+    print("Unit test passed")
+
+
+def timings(distChamfer, dim):
+    p1 = torch.rand(32, 2000, dim).cuda()
+    p2 = torch.rand(32, 1000, dim).cuda()
+    print("Timings : Start CUDA version")
+    start = time.time()
+    num_it = 100
+    for i in range(num_it):
+        points1 = Variable(p1, requires_grad=True)
+        points2 = Variable(p2)
+        mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2)
+        loss = torch.sum(mydist1)
+        loss.backward()
+    print(f"Elapsed time forward backward is {(time.time() - start)/num_it} seconds.")
+
+
+    print("Timings : Start Pythonic version")
+    start = time.time()
+    for i in range(num_it):
+        points1 = Variable(p1, requires_grad=True)
+        points2 = Variable(p2)
+        mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2)
+        loss = torch.sum(mydist1)
+        loss.backward()
+    print(f"Elapsed time forward backward is {(time.time() - start)/num_it} seconds.")
+
+
+
+dims = [2,3,5]
+for i,cham in enumerate([cham2D, cham3D, cham5D]):
+    print(f"testing Chamfer {dims[i]}D")
+    test_chamfer(cham, dims[i])
+    timings(cham, dims[i])
diff --git a/third_party/PyTorchEMD/.gitignore b/third_party/PyTorchEMD/.gitignore
new file mode 100644
index 0000000..8400d00
--- /dev/null
+++ b/third_party/PyTorchEMD/.gitignore
@@ -0,0 +1,5 @@
+__pycache__
+build
+dist
+emd_ext.egg-info +*.so diff --git a/third_party/PyTorchEMD/README.md b/third_party/PyTorchEMD/README.md new file mode 100644 index 0000000..f6b38fc --- /dev/null +++ b/third_party/PyTorchEMD/README.md @@ -0,0 +1,34 @@ +* adapted from https://github.com/daerduoCarey/PyTorchEMD + +--------------------------------- +# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD) + +## Dependency + +The code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0. + +## Usage + +First compile using + + python setup.py install + +Then, copy the lib file out to the main directory, + + cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so . + +Then, you can use it by simply + + from emd import earth_mover_distance + d = earth_mover_distance(p1, p2, transpose=False) # p1: B x N1 x 3, p2: B x N2 x 3 + +Check `test_emd_loss.py` for example. + +## Author + +The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps. + +## License + +MIT + diff --git a/third_party/PyTorchEMD/__init__.py b/third_party/PyTorchEMD/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/third_party/PyTorchEMD/backend.py b/third_party/PyTorchEMD/backend.py new file mode 100755 index 0000000..19943fc --- /dev/null +++ b/third_party/PyTorchEMD/backend.py @@ -0,0 +1,21 @@ +import os +import time +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +if not os.path.exists(os.path.join(_src_path, 'build_dynamic')): + os.makedirs(os.path.join(_src_path, 'build_dynamic')) +tic = time.time() +emd_cuda_dynamic = load(name='emd_ext', + extra_cflags=['-O3', '-std=c++17'], + ## build_directory=os.path.join(_src_path, 'build_dynamic'), + verbose=True, + sources=[ + os.path.join(_src_path, f) for f in [ + 'cuda/emd.cpp', + 'cuda/emd_kernel.cu', + ] + ]) +print('load emd_ext time: {:.3f}s'.format(time.time() - tic)) +__all__ = ['emd_cuda_dynamic'] diff --git a/third_party/PyTorchEMD/cuda/emd.cpp b/third_party/PyTorchEMD/cuda/emd.cpp new file mode 100755 index 0000000..b94db14 --- /dev/null +++ b/third_party/PyTorchEMD/cuda/emd.cpp @@ -0,0 +1,29 @@ +#ifndef _EMD +#define _EMD + +#include +#include + +//CUDA declarations +at::Tensor ApproxMatchForward( + const at::Tensor xyz1, + const at::Tensor xyz2); + +at::Tensor MatchCostForward( + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor match); + +std::vector MatchCostBackward( + const at::Tensor grad_cost, + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor match); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)"); + m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)"); + m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)"); +} + +#endif diff --git a/third_party/PyTorchEMD/cuda/emd_kernel.cu b/third_party/PyTorchEMD/cuda/emd_kernel.cu new file mode 100644 index 0000000..8da5d88 --- /dev/null +++ b/third_party/PyTorchEMD/cuda/emd_kernel.cu @@ -0,0 +1,398 @@ +/********************************** + * Original Author: Haoqiang Fan + * Modified by: Kaichun Mo + *********************************/ + +#ifndef _EMD_KERNEL +#define _EMD_KERNEL + +#include +#include + +#include +#include // at::cuda::getApplyGrid +#include + +#define CHECK_INPUT(x) + + +/******************************** +* Forward kernel for approxmatch +*********************************/ + +template +__global__ void approxmatch(int 
b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){ + scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; + scalar_t multiL,multiR; + if (n>=m){ + multiL=1; + multiR=n/m; + }else{ + multiL=m/n; + multiR=1; + } + const int Block=1024; + __shared__ scalar_t buf[Block*4]; + for (int i=blockIdx.x;i=-2;j--){ + scalar_t level=-powf(4.0f,j); + if (j==-2){ + level=0; + } + for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp); +//} + +/* ApproxMatch forward interface +Input: + xyz1: (B, N1, 3) # dataset_points + xyz2: (B, N2, 3) # query_points +Output: + match: (B, N2, N1) +*/ +at::Tensor ApproxMatchForward( + const at::Tensor xyz1, + const at::Tensor xyz2){ + const auto b = xyz1.size(0); + const auto n = xyz1.size(1); + const auto m = xyz2.size(1); + + CHECK_EQ(xyz2.size(0), b); + CHECK_EQ(xyz1.size(2), 3); + CHECK_EQ(xyz2.size(2), 3); + CHECK_INPUT(xyz1); + CHECK_INPUT(xyz2); + + auto match = at::zeros({b, m, n}, xyz1.type()); + auto temp = at::zeros({b, (n+m)*2}, xyz1.type()); + + AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] { + approxmatch<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), temp.data()); + })); + THCudaCheck(cudaGetLastError()); + + return match; +} + + +/******************************** +* Forward kernel for matchcost +*********************************/ + +template +__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){ + __shared__ scalar_t allsum[512]; + const int Block=1024; + __shared__ scalar_t buf[Block*3]; + for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out); +//} + +/* MatchCost forward interface +Input: + xyz1: (B, N1, 3) # dataset_points + xyz2: (B, N2, 3) # query_points + match: (B, N2, N1) +Output: + cost: (B) +*/ +at::Tensor MatchCostForward( + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor match){ + const auto b = xyz1.size(0); + const auto n = xyz1.size(1); + const auto m = xyz2.size(1); + + CHECK_EQ(xyz2.size(0), b); + CHECK_EQ(xyz1.size(2), 3); + CHECK_EQ(xyz2.size(2), 3); + CHECK_INPUT(xyz1); + CHECK_INPUT(xyz2); + + auto cost = at::zeros({b}, xyz1.type()); + + AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] { + matchcost<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), cost.data()); + })); + THCudaCheck(cudaGetLastError()); + + return cost; +} + + +/******************************** +* matchcostgrad2 kernel +*********************************/ + +template +__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){ + __shared__ scalar_t sum_grad[256*3]; + for (int i=blockIdx.x;i +__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){ + for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad1); +// matchcostgrad2<<>>(b,n,m,xyz1,xyz2,match,grad2); +//} + + +/* MatchCost backward interface +Input: + grad_cost: (B) # gradients on cost + xyz1: (B, N1, 3) # dataset_points + xyz2: (B, N2, 3) # query_points + match: (B, N2, N1) +Output: + grad1: 
(B, N1, 3) + grad2: (B, N2, 3) +*/ +std::vector MatchCostBackward( + const at::Tensor grad_cost, + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor match){ + const auto b = xyz1.size(0); + const auto n = xyz1.size(1); + const auto m = xyz2.size(1); + + CHECK_EQ(xyz2.size(0), b); + CHECK_EQ(xyz1.size(2), 3); + CHECK_EQ(xyz2.size(2), 3); + CHECK_INPUT(xyz1); + CHECK_INPUT(xyz2); + + auto grad1 = at::zeros({b, n, 3}, xyz1.type()); + auto grad2 = at::zeros({b, m, 3}, xyz1.type()); + + AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] { + matchcostgrad1<<<32,512>>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad1.data()); + matchcostgrad2<<>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad2.data()); + })); + THCudaCheck(cudaGetLastError()); + + return std::vector({grad1, grad2}); +} + +#endif diff --git a/third_party/PyTorchEMD/emd.py b/third_party/PyTorchEMD/emd.py new file mode 100755 index 0000000..ee3bdc7 --- /dev/null +++ b/third_party/PyTorchEMD/emd.py @@ -0,0 +1,52 @@ +import torch +# from backend import emd_cuda_dynamic as emd_cuda # jit compiling +from third_party.PyTorchEMD.backend import emd_cuda_dynamic as emd_cuda +from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd + +class EarthMoverDistanceFunction(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, xyz1, xyz2): + xyz1 = xyz1.contiguous() + xyz2 = xyz2.contiguous() + assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently." + match = emd_cuda.approxmatch_forward(xyz1, xyz2) + cost = emd_cuda.matchcost_forward(xyz1, xyz2, match) + ctx.save_for_backward(xyz1, xyz2, match) + return cost + + @staticmethod + @custom_bwd + def backward(ctx, grad_cost): + xyz1, xyz2, match = ctx.saved_tensors + grad_cost = grad_cost.contiguous() + grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match) + return grad_xyz1, grad_xyz2 + + +def earth_mover_distance(xyz1, xyz2, transpose=True): + """Earth Mover Distance (Approx) + + Args: + xyz1 (torch.Tensor): (b, 3, n1) + xyz2 (torch.Tensor): (b, 3, n1) + transpose (bool): whether to transpose inputs as it might be BCN format. + Extensions only support BNC format. 
+ + Returns: + cost (torch.Tensor): (b) + + """ + if xyz1.dim() == 2: + xyz1 = xyz1.unsqueeze(0) + if xyz2.dim() == 2: + xyz2 = xyz2.unsqueeze(0) + if transpose: + xyz1 = xyz1.transpose(1, 2) + xyz2 = xyz2.transpose(1, 2) + # xyz1: B,N,3 + N = xyz1.shape[1] + assert(xyz1.shape[-1] == 3), f'require it to be B,N,3; get: {xyz1.shape}' + cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) / float(N) + return cost + diff --git a/third_party/PyTorchEMD/emd_cuda.py b/third_party/PyTorchEMD/emd_cuda.py new file mode 100644 index 0000000..b145b27 --- /dev/null +++ b/third_party/PyTorchEMD/emd_cuda.py @@ -0,0 +1,9 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, importlib.util + __file__ = pkg_resources.resource_filename(__name__, 'emd_cuda.cpython-38-x86_64-linux-gnu.so') + __loader__ = None; del __bootstrap__, __loader__ + spec = importlib.util.spec_from_file_location(__name__,__file__) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) +__bootstrap__() diff --git a/third_party/PyTorchEMD/emd_nograd.py b/third_party/PyTorchEMD/emd_nograd.py new file mode 100644 index 0000000..96461ae --- /dev/null +++ b/third_party/PyTorchEMD/emd_nograd.py @@ -0,0 +1,45 @@ +import torch +#import emd_cuda +# from evaluation.PyTorchEMD import emd_cuda +from third_party.PyTorchEMD.backend import emd_cuda_dynamic as emd_cuda + + +class EarthMoverDistanceFunctionNoGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + xyz1 = xyz1.contiguous() + xyz2 = xyz2.contiguous() + assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently." + match = emd_cuda.approxmatch_forward(xyz1, xyz2) + cost = emd_cuda.matchcost_forward(xyz1, xyz2, match) + # ctx.save_for_backward(xyz1, xyz2, match) + return cost + + +def earth_mover_distance_nograd(xyz1, xyz2, transpose=True): + """Earth Mover Distance (Approx) + + Args: + xyz1 (torch.Tensor): (b, 3, n1) + xyz2 (torch.Tensor): (b, 3, n1) + transpose (bool): whether to transpose inputs as it might be BCN format. + Extensions only support BNC format. + + Returns: + cost (torch.Tensor): (b) + + """ + if xyz1.dim() == 2: + xyz1 = xyz1.unsqueeze(0) + if xyz2.dim() == 2: + xyz2 = xyz2.unsqueeze(0) + if transpose: + xyz1 = xyz1.transpose(1, 2) + xyz2 = xyz2.transpose(1, 2) + # xyz1: B,N,3 + N = xyz1.shape[1] + assert(xyz1.shape[-1] == 3), f'require it to be B,N,3; get: {xyz1.shape}' + #print('xyz1: ', xyz1.shape, xyz2.shape, xyz1.min(), xyz1.max(), xyz2.min(), xyz2.max()) + cost = EarthMoverDistanceFunctionNoGrad.apply(xyz1, xyz2) / float(N) + return cost + diff --git a/third_party/PyTorchEMD/emd_static.py b/third_party/PyTorchEMD/emd_static.py new file mode 100755 index 0000000..ac44fc3 --- /dev/null +++ b/third_party/PyTorchEMD/emd_static.py @@ -0,0 +1,49 @@ +import torch +import emd_cuda + + +class EarthMoverDistanceFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + xyz1 = xyz1.contiguous() + xyz2 = xyz2.contiguous() + assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently." 
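# Note: all three wrappers (emd.py, emd_nograd.py, emd_static.py) share the same
# two-step scheme: approxmatch_forward computes a soft assignment `match` of
# shape (B, N2, N1), matchcost_forward contracts it with the pairwise distances
# into a per-batch cost, and matchcost_backward differentiates through the saved
# match. A hedged usage sketch for the JIT build (assumes the emd_ext extension
# compiles; the custom_fwd(cast_inputs=torch.float32) in emd.py keeps the kernel
# in fp32 under autocast):
#
#   import torch
#   from third_party.PyTorchEMD.emd import earth_mover_distance
#
#   p1 = torch.rand(4, 1024, 3, device='cuda', requires_grad=True)
#   p2 = torch.rand(4, 1024, 3, device='cuda')
#   with torch.cuda.amp.autocast():
#       loss = earth_mover_distance(p1, p2, transpose=False).mean()  # (B,) -> scalar
#   loss.backward()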
+ match = emd_cuda.approxmatch_forward(xyz1, xyz2) + cost = emd_cuda.matchcost_forward(xyz1, xyz2, match) + ctx.save_for_backward(xyz1, xyz2, match) + return cost + + @staticmethod + def backward(ctx, grad_cost): + xyz1, xyz2, match = ctx.saved_tensors + grad_cost = grad_cost.contiguous() + grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match) + return grad_xyz1, grad_xyz2 + + +def earth_mover_distance(xyz1, xyz2, transpose=True): + """Earth Mover Distance (Approx) + + Args: + xyz1 (torch.Tensor): (b, 3, n1) + xyz2 (torch.Tensor): (b, 3, n1) + transpose (bool): whether to transpose inputs as it might be BCN format. + Extensions only support BNC format. + + Returns: + cost (torch.Tensor): (b) + + """ + if xyz1.dim() == 2: + xyz1 = xyz1.unsqueeze(0) + if xyz2.dim() == 2: + xyz2 = xyz2.unsqueeze(0) + if transpose: + xyz1 = xyz1.transpose(1, 2) + xyz2 = xyz2.transpose(1, 2) + # xyz1: B,N,3 + N = xyz1.shape[1] + assert(xyz1.shape[-1] == 3), f'require it to be B,N,3; get: {xyz1.shape}' + cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) / float(N) + return cost + diff --git a/third_party/PyTorchEMD/setup.py b/third_party/PyTorchEMD/setup.py new file mode 100755 index 0000000..f648c3e --- /dev/null +++ b/third_party/PyTorchEMD/setup.py @@ -0,0 +1,27 @@ +"""Setup extension + +Notes: + If extra_compile_args is provided, you need to provide different instances for different extensions. + Refer to https://github.com/pytorch/pytorch/issues/20169 + +""" + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name='emd_ext', + ext_modules=[ + CUDAExtension( + name='emd_cuda', + sources=[ + 'cuda/emd.cpp', + 'cuda/emd_kernel.cu', + ], + extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} + ), + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/third_party/PyTorchEMD/test_emd_loss.py b/third_party/PyTorchEMD/test_emd_loss.py new file mode 100644 index 0000000..66aa33c --- /dev/null +++ b/third_party/PyTorchEMD/test_emd_loss.py @@ -0,0 +1,44 @@ +import torch +import numpy as np +import time +from emd import earth_mover_distance + +# gt +p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda() +p1 = p1.repeat(3, 1, 1) +p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda() +p2 = p2.repeat(3, 1, 1) +print(p1) +print(p2) +p1.requires_grad = True +p2.requires_grad = True + +gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \ + (((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \ + (((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3 +print('gt_dist: ', gt_dist) + +gt_dist.backward() +print(p1.grad) +print(p2.grad) + +# emd +p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda() +p1 = p1.repeat(3, 1, 1) +p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda() +p2 = p2.repeat(3, 1, 1) +print(p1) +print(p2) +p1.requires_grad = True +p2.requires_grad = True + +d = earth_mover_distance(p1, p2, transpose=False) +print(d) + +loss = d[0] / 2 + d[1] * 2 + d[2] / 3 +print(loss) + +loss.backward() +print(p1.grad) +print(p2.grad) + diff --git a/third_party/pvcnn/LICENSE b/third_party/pvcnn/LICENSE new file mode 100644 index 0000000..22913f5 --- /dev/null +++ b/third_party/pvcnn/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Zhijian Liu, Haotian Tang, Yujun Lin + 
+Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/pvcnn/README.md b/third_party/pvcnn/README.md new file mode 100644 index 0000000..390d9a4 --- /dev/null +++ b/third_party/pvcnn/README.md @@ -0,0 +1,2 @@ +* all the code under this folder is based on the code under https://github.com/mit-han-lab/pvcnn/tree/master/modules + diff --git a/third_party/pvcnn/functional/__init__.py b/third_party/pvcnn/functional/__init__.py new file mode 100644 index 0000000..b03ee71 --- /dev/null +++ b/third_party/pvcnn/functional/__init__.py @@ -0,0 +1,7 @@ +from third_party.pvcnn.functional.ball_query import ball_query +from third_party.pvcnn.functional.devoxelization import trilinear_devoxelize +from third_party.pvcnn.functional.grouping import grouping +from third_party.pvcnn.functional.interpolatation import nearest_neighbor_interpolate +from third_party.pvcnn.functional.loss import kl_loss, huber_loss +from third_party.pvcnn.functional.sampling import gather, furthest_point_sample, logits_mask +from third_party.pvcnn.functional.voxelization import avg_voxelize diff --git a/third_party/pvcnn/functional/backend.py b/third_party/pvcnn/functional/backend.py new file mode 100644 index 0000000..8d8ed75 --- /dev/null +++ b/third_party/pvcnn/functional/backend.py @@ -0,0 +1,29 @@ +import os + +from torch.utils.cpp_extension import load +_src_path = os.path.dirname(os.path.abspath(__file__)) + +if not os.path.exists(os.path.join(_src_path, 'build')): + os.makedirs(os.path.join(_src_path, 'build')) +_backend = load(name='_pvcnn_backend', + extra_cflags=['-O3', '-std=c++17'], + verbose=True, + sources=[ + os.path.join(_src_path, 'src', f) for f in [ + 'ball_query/ball_query.cpp', + 'ball_query/ball_query.cu', + 'grouping/grouping.cpp', + 'grouping/grouping.cu', + 'interpolate/neighbor_interpolate.cpp', + 'interpolate/neighbor_interpolate.cu', + 'interpolate/trilinear_devox.cpp', + 'interpolate/trilinear_devox.cu', + 'sampling/sampling.cpp', + 'sampling/sampling.cu', + 'voxelization/vox.cpp', + 'voxelization/vox.cu', + 'bindings.cpp', + ] + ]) + +__all__ = ['_backend'] diff --git a/third_party/pvcnn/functional/ball_query.py b/third_party/pvcnn/functional/ball_query.py new file mode 100644 index 0000000..3a3699f --- /dev/null +++ b/third_party/pvcnn/functional/ball_query.py @@ -0,0 +1,20 @@ +from torch.autograd import Function + +from third_party.pvcnn.functional.backend import _backend + +__all__ = ['ball_query'] + + +def ball_query(centers_coords, points_coords, radius, num_neighbors): + """ + :param centers_coords: 
coordinates of centers, FloatTensor[B, 3, M] + :param points_coords: coordinates of points, FloatTensor[B, 3, N] + :param radius: float, radius of ball query + :param num_neighbors: int, maximum number of neighbors + :return: + neighbor_indices: indices of neighbors, IntTensor[B, M, U] + """ + centers_coords = centers_coords[:,:3].contiguous() + points_coords = points_coords[:,:3].contiguous() + return _backend.ball_query(centers_coords, points_coords, radius, + num_neighbors) diff --git a/third_party/pvcnn/functional/devoxelization.py b/third_party/pvcnn/functional/devoxelization.py new file mode 100644 index 0000000..29bc365 --- /dev/null +++ b/third_party/pvcnn/functional/devoxelization.py @@ -0,0 +1,45 @@ +from torch.autograd import Function + +from third_party.pvcnn.functional.backend import _backend + +__all__ = ['trilinear_devoxelize'] + + +class TrilinearDevoxelization(Function): + @staticmethod + def forward(ctx, features, coords, resolution, is_training=True): + """ + :param ctx: + :param coords: the coordinates of points, FloatTensor[B, 3, N] + :param features: FloatTensor[B, C, R, R, R] + :param resolution: int, the voxel resolution + :param is_training: bool, training mode + :return: + FloatTensor[B, C, N] + """ + B, C = features.shape[:2] + features = features.contiguous().view(B, C, -1) + coords = coords[:,:3].contiguous() + outs, inds, wgts = _backend.trilinear_devoxelize_forward( + resolution, is_training, coords, features) + if is_training: + ctx.save_for_backward(inds, wgts) + ctx.r = resolution + return outs + + @staticmethod + def backward(ctx, grad_output): + """ + :param ctx: + :param grad_output: gradient of outputs, FloatTensor[B, C, N] + :return: + gradient of inputs, FloatTensor[B, C, R, R, R] + """ + inds, wgts = ctx.saved_tensors + grad_inputs = _backend.trilinear_devoxelize_backward( + grad_output.contiguous(), inds, wgts, ctx.r) + return grad_inputs.view(grad_output.size(0), grad_output.size(1), + ctx.r, ctx.r, ctx.r), None, None, None + + +trilinear_devoxelize = TrilinearDevoxelization.apply diff --git a/third_party/pvcnn/functional/grouping.py b/third_party/pvcnn/functional/grouping.py new file mode 100644 index 0000000..cbf0850 --- /dev/null +++ b/third_party/pvcnn/functional/grouping.py @@ -0,0 +1,33 @@ +from torch.autograd import Function + +# from modules.functional.backend import _backend +from third_party.pvcnn.functional.backend import _backend + +__all__ = ['grouping'] + + +class Grouping(Function): + @staticmethod + def forward(ctx, features, indices): + """ + :param ctx: + :param features: features of points, FloatTensor[B, C, N] + :param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors + :return: + grouped_features: grouped features, FloatTensor[B, C, M, U] + """ + features = features.contiguous() + indices = indices.contiguous() + ctx.save_for_backward(indices) + ctx.num_points = features.size(-1) + return _backend.grouping_forward(features, indices) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + grad_features = _backend.grouping_backward(grad_output.contiguous(), + indices, ctx.num_points) + return grad_features, None + + +grouping = Grouping.apply diff --git a/third_party/pvcnn/functional/interpolatation.py b/third_party/pvcnn/functional/interpolatation.py new file mode 100644 index 0000000..02566a7 --- /dev/null +++ b/third_party/pvcnn/functional/interpolatation.py @@ -0,0 +1,54 @@ +from torch.autograd import Function + +# from modules.functional.backend import 
_backend +from third_party.pvcnn.functional.backend import _backend +import torch +from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd + +__all__ = ['nearest_neighbor_interpolate'] + + +class NeighborInterpolation(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, points_coords, centers_coords, centers_features): + """ + :param ctx: + :param points_coords: coordinates of points, FloatTensor[B, 3, N] + :param centers_coords: coordinates of centers, FloatTensor[B, 3, M] + :param centers_features: features of centers, FloatTensor[B, C, M] + :return: + points_features: features of points, FloatTensor[B, C, N] + """ + centers_coords = centers_coords[:,:3].contiguous() + points_coords = points_coords[:,:3].contiguous() + centers_features = centers_features.contiguous() + points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward( + points_coords, centers_coords, centers_features) + ctx.save_for_backward(indices, weights) + ctx.num_centers = centers_coords.size(-1) + return points_features + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + indices, weights = ctx.saved_tensors + grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward( + grad_output.contiguous(), indices, weights, ctx.num_centers) + return None, None, grad_centers_features + + +nearest_neighbor_interpolate = NeighborInterpolation.apply + +#def nearest_neighbor_interpolate(points_coords, centers_coords, centers_features): +# # points_coords: (B,6, 64) +# # centers_coords: (B,6, 16) +# # centers_features: (B,128,16) +# # interpolated_features: (B,128,64) +# B = points_coords.shape[0] +# D = centers_features.shape[1] +# N = points_coords.shape[2] +# output = torch.zeros(B,D,N).to(points_coords.shape) +# for b in range(B): +# for n in range(N): +# points_coords_cur = points_coords diff --git a/third_party/pvcnn/functional/loss.py b/third_party/pvcnn/functional/loss.py new file mode 100644 index 0000000..6f1f6a8 --- /dev/null +++ b/third_party/pvcnn/functional/loss.py @@ -0,0 +1,18 @@ +import torch +import torch.nn.functional as F + +__all__ = ['kl_loss', 'huber_loss'] + + +def kl_loss(x, y): + x = F.softmax(x.detach(), dim=1) + y = F.log_softmax(y, dim=1) + return torch.mean(torch.sum(x * (torch.log(x) - y), dim=1)) + + +def huber_loss(error, delta): + abs_error = torch.abs(error) + quadratic = torch.min(abs_error, + torch.full_like(abs_error, fill_value=delta)) + losses = 0.5 * (quadratic**2) + delta * (abs_error - quadratic) + return torch.mean(losses) diff --git a/third_party/pvcnn/functional/sampling.py b/third_party/pvcnn/functional/sampling.py new file mode 100644 index 0000000..7ee7207 --- /dev/null +++ b/third_party/pvcnn/functional/sampling.py @@ -0,0 +1,100 @@ +import numpy as np +import torch +from torch.autograd import Function + +# from modules.functional.backend import _backend +from third_party.pvcnn.functional.backend import _backend + +__all__ = ['gather', 'furthest_point_sample', 'logits_mask'] + + +class Gather(Function): + @staticmethod + def forward(ctx, features, indices): + """ + Gather + :param ctx: + :param features: features of points, FloatTensor[B, C, N] + :param indices: centers' indices in points, IntTensor[b, m] + :return: + centers_coords: coordinates of sampled centers, FloatTensor[B, C, M] + """ + features = features.contiguous() + indices = indices.int().contiguous() + ctx.save_for_backward(indices) + ctx.num_points = features.size(-1) + return 
_backend.gather_features_forward(features, indices) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + grad_features = _backend.gather_features_backward( + grad_output.contiguous(), indices, ctx.num_points) + return grad_features, None + + +gather = Gather.apply + + +def furthest_point_sample(coords, num_samples, normals=None): + """ + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance to the sampled point set + :param coords: coordinates of points, FloatTensor[B, 3, N] + :param num_samples: int, M + :return: + center_coords: coordinates of sampled centers, FloatTensor[B, 3, M] + """ + assert(len(coords.shape) == 3 and coords.shape[1] == 3), f'expect input as B,3,N; get: {coords.shape}' + coords = coords.contiguous() + indices = _backend.furthest_point_sampling(coords, num_samples) + centers_coords = gather(coords, indices) + if normals is not None: + center_normals = gather(normals, indices) + return centers_coords if normals is None else (centers_coords, center_normals) + + +def logits_mask(coords, logits, num_points_per_object): + """ + Use logits to sample points + :param coords: coords of points, FloatTensor[B, 3, N] + :param logits: binary classification logits, FloatTensor[B, 2, N] + :param num_points_per_object: M, #points per object after masking, int + :return: + selected_coords: FloatTensor[B, 3, M] + masked_coords_mean: mean coords of selected points, FloatTensor[B, 3] + mask: mask to select points, BoolTensor[B, N] + """ + batch_size, _, num_points = coords.shape + mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N] + num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1] + masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N] + masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max( + num_candidates, torch.ones_like(num_candidates)).float() # [B, C] + selected_indices = torch.zeros((batch_size, num_points_per_object), + device=coords.device, + dtype=torch.int32) + for i in range(batch_size): + current_mask = mask[i] # [N] + current_candidates = current_mask.nonzero().view(-1) + current_num_candidates = current_candidates.numel() + if current_num_candidates >= num_points_per_object: + choices = np.random.choice(current_num_candidates, + num_points_per_object, + replace=False) + selected_indices[i] = current_candidates[choices] + elif current_num_candidates > 0: + choices = np.concatenate([ + np.arange(current_num_candidates).repeat( + num_points_per_object // current_num_candidates), + np.random.choice(current_num_candidates, + num_points_per_object % + current_num_candidates, + replace=False) + ]) + np.random.shuffle(choices) + selected_indices[i] = current_candidates[choices] + selected_coords = gather( + masked_coords - masked_coords_mean.view(batch_size, -1, 1), + selected_indices) + return selected_coords, masked_coords_mean, mask diff --git a/third_party/pvcnn/functional/src/ball_query/ball_query.cpp b/third_party/pvcnn/functional/src/ball_query/ball_query.cpp new file mode 100644 index 0000000..5ae1fb6 --- /dev/null +++ b/third_party/pvcnn/functional/src/ball_query/ball_query.cpp @@ -0,0 +1,30 @@ +#include "ball_query.hpp" +#include "ball_query.cuh" + +#include "../utils.hpp" + +at::Tensor ball_query_forward(at::Tensor centers_coords, + at::Tensor points_coords, const float radius, + const int num_neighbors) { + CHECK_CUDA(centers_coords); + CHECK_CUDA(points_coords); + CHECK_CONTIGUOUS(centers_coords); + 
CHECK_CONTIGUOUS(points_coords); + CHECK_IS_FLOAT(centers_coords); + CHECK_IS_FLOAT(points_coords); + + int b = centers_coords.size(0); + int m = centers_coords.size(2); + int n = points_coords.size(2); + + at::Tensor neighbors_indices = torch::zeros( + {b, m, num_neighbors}, + at::device(centers_coords.device()).dtype(at::ScalarType::Int)); + + ball_query(b, n, m, radius * radius, num_neighbors, + centers_coords.data_ptr(), + points_coords.data_ptr(), + neighbors_indices.data_ptr()); + + return neighbors_indices; +} diff --git a/third_party/pvcnn/functional/src/ball_query/ball_query.cu b/third_party/pvcnn/functional/src/ball_query/ball_query.cu new file mode 100644 index 0000000..079e3cb --- /dev/null +++ b/third_party/pvcnn/functional/src/ball_query/ball_query.cu @@ -0,0 +1,59 @@ +#include +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: ball query + Args: + b : batch size + n : number of points in point clouds + m : number of query centers + r2 : ball query radius ** 2 + u : maximum number of neighbors + centers_coords: coordinates of centers, FloatTensor[b, 3, m] + points_coords : coordinates of points, FloatTensor[b, 3, n] + neighbors_indices : neighbor indices in points, IntTensor[b, m, u] +*/ +__global__ void ball_query_kernel(int b, int n, int m, float r2, int u, + const float *__restrict__ centers_coords, + const float *__restrict__ points_coords, + int *__restrict__ neighbors_indices) { + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + points_coords += batch_index * n * 3; + centers_coords += batch_index * m * 3; + neighbors_indices += batch_index * m * u; + + for (int j = index; j < m; j += stride) { + float center_x = centers_coords[j]; + float center_y = centers_coords[j + m]; + float center_z = centers_coords[j + m + m]; + for (int k = 0, cnt = 0; k < n && cnt < u; ++k) { + float dx = center_x - points_coords[k]; + float dy = center_y - points_coords[k + n]; + float dz = center_z - points_coords[k + n + n]; + float d2 = dx * dx + dy * dy + dz * dz; + if (d2 < r2) { + if (cnt == 0) { + for (int v = 0; v < u; ++v) { + neighbors_indices[j * u + v] = k; + } + } + neighbors_indices[j * u + cnt] = k; + ++cnt; + } + } + } +} + +void ball_query(int b, int n, int m, float r2, int u, + const float *centers_coords, const float *points_coords, + int *neighbors_indices) { + ball_query_kernel<<>>( + b, n, m, r2, u, centers_coords, points_coords, neighbors_indices); + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pvcnn/functional/src/ball_query/ball_query.cuh b/third_party/pvcnn/functional/src/ball_query/ball_query.cuh new file mode 100644 index 0000000..ba32492 --- /dev/null +++ b/third_party/pvcnn/functional/src/ball_query/ball_query.cuh @@ -0,0 +1,8 @@ +#ifndef _BALL_QUERY_CUH +#define _BALL_QUERY_CUH + +void ball_query(int b, int n, int m, float r2, int u, + const float *centers_coords, const float *points_coords, + int *neighbors_indices); + +#endif diff --git a/third_party/pvcnn/functional/src/ball_query/ball_query.hpp b/third_party/pvcnn/functional/src/ball_query/ball_query.hpp new file mode 100644 index 0000000..d87bbd9 --- /dev/null +++ b/third_party/pvcnn/functional/src/ball_query/ball_query.hpp @@ -0,0 +1,10 @@ +#ifndef _BALL_QUERY_HPP +#define _BALL_QUERY_HPP + +#include + +at::Tensor ball_query_forward(at::Tensor centers_coords, + at::Tensor points_coords, const float radius, + const int num_neighbors); + +#endif diff --git a/third_party/pvcnn/functional/src/bindings.cpp 
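
The four `ball_query` files above complete the first of the pvcnn kernels; together with `furthest_point_sample` and `grouping` they form the usual PointNet++-style neighborhood gather. A sketch of how the exported Python ops compose (shapes follow the docstrings; toy sizes assumed):

    import torch
    from third_party.pvcnn.functional import furthest_point_sample, ball_query, grouping

    coords = torch.rand(2, 3, 4096, device='cuda')   # (B, 3, N) point coordinates
    feats  = torch.rand(2, 64, 4096, device='cuda')  # (B, C, N) point features
    centers = furthest_point_sample(coords, 512)     # (B, 3, 512) FPS centers
    idx = ball_query(centers, coords, 0.1, 32)       # (B, 512, 32) neighbor indices
    grouped = grouping(feats, idx)                   # (B, C, 512, 32) gathered features
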
b/third_party/pvcnn/functional/src/bindings.cpp new file mode 100644 index 0000000..994e01b --- /dev/null +++ b/third_party/pvcnn/functional/src/bindings.cpp @@ -0,0 +1,37 @@ +#include + +#include "ball_query/ball_query.hpp" +#include "grouping/grouping.hpp" +#include "interpolate/neighbor_interpolate.hpp" +#include "interpolate/trilinear_devox.hpp" +#include "sampling/sampling.hpp" +#include "voxelization/vox.hpp" + +PYBIND11_MODULE(_pvcnn_backend, m) { + m.def("gather_features_forward", &gather_features_forward, + "Gather Centers' Features forward (CUDA)"); + m.def("gather_features_backward", &gather_features_backward, + "Gather Centers' Features backward (CUDA)"); + m.def("furthest_point_sampling", &furthest_point_sampling_forward, + "Furthest Point Sampling (CUDA)"); + m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)"); + m.def("grouping_forward", &grouping_forward, + "Grouping Features forward (CUDA)"); + m.def("grouping_backward", &grouping_backward, + "Grouping Features backward (CUDA)"); + m.def("three_nearest_neighbors_interpolate_forward", + &three_nearest_neighbors_interpolate_forward, + "3 Nearest Neighbors Interpolate forward (CUDA)"); + m.def("three_nearest_neighbors_interpolate_backward", + &three_nearest_neighbors_interpolate_backward, + "3 Nearest Neighbors Interpolate backward (CUDA)"); + + m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward, + "Trilinear Devoxelization forward (CUDA)"); + m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward, + "Trilinear Devoxelization backward (CUDA)"); + m.def("avg_voxelize_forward", &avg_voxelize_forward, + "Voxelization forward with average pooling (CUDA)"); + m.def("avg_voxelize_backward", &avg_voxelize_backward, + "Voxelization backward (CUDA)"); +} diff --git a/third_party/pvcnn/functional/src/cuda_utils.cuh b/third_party/pvcnn/functional/src/cuda_utils.cuh new file mode 100644 index 0000000..01bf551 --- /dev/null +++ b/third_party/pvcnn/functional/src/cuda_utils.cuh @@ -0,0 +1,39 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include +#include +#include + +#include +#include + +#include + +#define MAXIMUM_THREADS 512 + +inline int optimal_num_threads(int work_size) { + const int pow_2 = std::log2(static_cast(work_size)); + return max(min(1 << pow_2, MAXIMUM_THREADS), 1); +} + +inline dim3 optimal_block_config(int x, int y) { + const int x_threads = optimal_num_threads(x); + const int y_threads = + max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + return block_config; +} + +#define CUDA_CHECK_ERRORS() \ + { \ + cudaError_t err = cudaGetLastError(); \ + if (cudaSuccess != err) { \ + fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ + cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ + __FILE__); \ + exit(-1); \ + } \ + } + +#endif diff --git a/third_party/pvcnn/functional/src/grouping/grouping.cpp b/third_party/pvcnn/functional/src/grouping/grouping.cpp new file mode 100644 index 0000000..4f97650 --- /dev/null +++ b/third_party/pvcnn/functional/src/grouping/grouping.cpp @@ -0,0 +1,44 @@ +#include "grouping.hpp" +#include "grouping.cuh" + +#include "../utils.hpp" + +at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) { + CHECK_CUDA(features); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(features); + CHECK_IS_INT(indices); + + int b = features.size(0); + int c = features.size(1); + int n = features.size(2); + int m = 
indices.size(1); + int u = indices.size(2); + at::Tensor output = torch::zeros( + {b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float)); + grouping(b, c, n, m, u, features.data_ptr(), indices.data_ptr(), + output.data_ptr()); + return output; +} + +at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices, + const int n) { + CHECK_CUDA(grad_y); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_INT(indices); + + int b = grad_y.size(0); + int c = grad_y.size(1); + int m = indices.size(1); + int u = indices.size(2); + at::Tensor grad_x = torch::zeros( + {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + grouping_grad(b, c, n, m, u, grad_y.data_ptr(), + indices.data_ptr(), grad_x.data_ptr()); + return grad_x; +} diff --git a/third_party/pvcnn/functional/src/grouping/grouping.cu b/third_party/pvcnn/functional/src/grouping/grouping.cu new file mode 100644 index 0000000..0cf561a --- /dev/null +++ b/third_party/pvcnn/functional/src/grouping/grouping.cu @@ -0,0 +1,85 @@ +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: grouping features of neighbors (forward) + Args: + b : batch size + c : #channles of features + n : number of points in point clouds + m : number of query centers + u : maximum number of neighbors + features: points' features, FloatTensor[b, c, n] + indices : neighbor indices in points, IntTensor[b, m, u] + out : gathered features, FloatTensor[b, c, m, u] +*/ +__global__ void grouping_kernel(int b, int c, int n, int m, int u, + const float *__restrict__ features, + const int *__restrict__ indices, + float *__restrict__ out) { + int batch_index = blockIdx.x; + features += batch_index * n * c; + indices += batch_index * m * u; + out += batch_index * m * u * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * m; i += stride) { + const int l = i / m; + const int j = i % m; + for (int k = 0; k < u; ++k) { + out[(l * m + j) * u + k] = features[l * n + indices[j * u + k]]; + } + } +} + +void grouping(int b, int c, int n, int m, int u, const float *features, + const int *indices, float *out) { + grouping_kernel<<>>(b, c, n, m, u, features, + indices, out); + CUDA_CHECK_ERRORS(); +} + +/* + Function: grouping features of neighbors (backward) + Args: + b : batch size + c : #channles of features + n : number of points in point clouds + m : number of query centers + u : maximum number of neighbors + grad_y : grad of gathered features, FloatTensor[b, c, m, u] + indices : neighbor indices in points, IntTensor[b, m, u] + grad_x: grad of points' features, FloatTensor[b, c, n] +*/ +__global__ void grouping_grad_kernel(int b, int c, int n, int m, int u, + const float *__restrict__ grad_y, + const int *__restrict__ indices, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + grad_y += batch_index * m * u * c; + indices += batch_index * m * u; + grad_x += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * m; i += stride) { + const int l = i / m; + const int j = i % m; + for (int k = 0; k < u; ++k) { + atomicAdd(grad_x + l * n + indices[j * u + k], + grad_y[(l * m + j) * u + k]); + } + } +} + +void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y, + const int *indices, float *grad_x) { + grouping_grad_kernel<<>>( + b, c, n, m, u, grad_y, 
indices, grad_x); + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pvcnn/functional/src/grouping/grouping.cuh b/third_party/pvcnn/functional/src/grouping/grouping.cuh new file mode 100644 index 0000000..c8a114f --- /dev/null +++ b/third_party/pvcnn/functional/src/grouping/grouping.cuh @@ -0,0 +1,9 @@ +#ifndef _GROUPING_CUH +#define _GROUPING_CUH + +void grouping(int b, int c, int n, int m, int u, const float *features, + const int *indices, float *out); +void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y, + const int *indices, float *grad_x); + +#endif \ No newline at end of file diff --git a/third_party/pvcnn/functional/src/grouping/grouping.hpp b/third_party/pvcnn/functional/src/grouping/grouping.hpp new file mode 100644 index 0000000..3f5733d --- /dev/null +++ b/third_party/pvcnn/functional/src/grouping/grouping.hpp @@ -0,0 +1,10 @@ +#ifndef _GROUPING_HPP +#define _GROUPING_HPP + +#include + +at::Tensor grouping_forward(at::Tensor features, at::Tensor indices); +at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices, + const int n); + +#endif diff --git a/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cpp b/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cpp new file mode 100644 index 0000000..fc73c43 --- /dev/null +++ b/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cpp @@ -0,0 +1,65 @@ +#include "neighbor_interpolate.hpp" +#include "neighbor_interpolate.cuh" + +#include "../utils.hpp" + +std::vector +three_nearest_neighbors_interpolate_forward(at::Tensor points_coords, + at::Tensor centers_coords, + at::Tensor centers_features) { + CHECK_CUDA(points_coords); + CHECK_CUDA(centers_coords); + CHECK_CUDA(centers_features); + CHECK_CONTIGUOUS(points_coords); + CHECK_CONTIGUOUS(centers_coords); + CHECK_CONTIGUOUS(centers_features); + CHECK_IS_FLOAT(points_coords); + CHECK_IS_FLOAT(centers_coords); + CHECK_IS_FLOAT(centers_features); + + int b = centers_features.size(0); + int c = centers_features.size(1); + int m = centers_features.size(2); + int n = points_coords.size(2); + + at::Tensor indices = torch::zeros( + {b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int)); + at::Tensor weights = torch::zeros( + {b, 3, n}, + at::device(points_coords.device()).dtype(at::ScalarType::Float)); + at::Tensor output = torch::zeros( + {b, c, n}, + at::device(centers_features.device()).dtype(at::ScalarType::Float)); + + three_nearest_neighbors_interpolate( + b, c, m, n, points_coords.data_ptr(), + centers_coords.data_ptr(), centers_features.data_ptr(), + indices.data_ptr(), weights.data_ptr(), + output.data_ptr()); + return {output, indices, weights}; +} + +at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y, + at::Tensor indices, + at::Tensor weights, + const int m) { + CHECK_CUDA(grad_y); + CHECK_CUDA(indices); + CHECK_CUDA(weights); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(indices); + CHECK_CONTIGUOUS(weights); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_INT(indices); + CHECK_IS_FLOAT(weights); + + int b = grad_y.size(0); + int c = grad_y.size(1); + int n = grad_y.size(2); + at::Tensor grad_x = torch::zeros( + {b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + three_nearest_neighbors_interpolate_grad( + b, c, n, m, grad_y.data_ptr(), indices.data_ptr(), + weights.data_ptr(), grad_x.data_ptr()); + return grad_x; +} diff --git a/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cu 
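
`three_nearest_neighbors_interpolate` is the feature-propagation counterpart of the downsampling ops: each dense point receives a distance-weighted blend of its three nearest centers' features, with the weights and indices cached for the backward pass (see the kernels below). A usage sketch via the Python wrapper exported as `nearest_neighbor_interpolate` (toy shapes assumed):

    import torch
    from third_party.pvcnn.functional import nearest_neighbor_interpolate

    points  = torch.rand(2, 3, 4096, device='cuda')   # (B, 3, N) dense coordinates
    centers = torch.rand(2, 3, 512, device='cuda')    # (B, 3, M) sampled coordinates
    cfeats  = torch.rand(2, 128, 512, device='cuda')  # (B, C, M) center features
    pfeats = nearest_neighbor_interpolate(points, centers, cfeats)  # (B, C, N)
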
b/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cu new file mode 100644 index 0000000..8168507 --- /dev/null +++ b/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cu @@ -0,0 +1,181 @@ +#include +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: three nearest neighbors + Args: + b : batch size + n : number of points in point clouds + m : number of query centers + points_coords : coordinates of points, FloatTensor[b, 3, n] + centers_coords: coordinates of centers, FloatTensor[b, 3, m] + weights : weights of nearest 3 centers to the point, + FloatTensor[b, 3, n] + indices : indices of nearest 3 centers to the point, + IntTensor[b, 3, n] +*/ +__global__ void three_nearest_neighbors_kernel( + int b, int n, int m, const float *__restrict__ points_coords, + const float *__restrict__ centers_coords, float *__restrict__ weights, + int *__restrict__ indices) { + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + points_coords += batch_index * 3 * n; + weights += batch_index * 3 * n; + indices += batch_index * 3 * n; + centers_coords += batch_index * 3 * m; + + for (int j = index; j < n; j += stride) { + float ux = points_coords[j]; + float uy = points_coords[j + n]; + float uz = points_coords[j + n + n]; + + double best0 = 1e40, best1 = 1e40, best2 = 1e40; + int besti0 = 0, besti1 = 0, besti2 = 0; + for (int k = 0; k < m; ++k) { + float x = centers_coords[k]; + float y = centers_coords[k + m]; + float z = centers_coords[k + m + m]; + float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best2) { + best2 = d; + besti2 = k; + if (d < best1) { + best2 = best1; + besti2 = besti1; + best1 = d; + besti1 = k; + if (d < best0) { + best1 = best0; + besti1 = besti0; + best0 = d; + besti0 = k; + } + } + } + } + best0 = max(min(1e10f, best0), 1e-10f); + best1 = max(min(1e10f, best1), 1e-10f); + best2 = max(min(1e10f, best2), 1e-10f); + float d0d1 = best0 * best1; + float d0d2 = best0 * best2; + float d1d2 = best1 * best2; + float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2); + weights[j] = d1d2 * d0d1d2; + indices[j] = besti0; + weights[j + n] = d0d2 * d0d1d2; + indices[j + n] = besti1; + weights[j + n + n] = d0d1 * d0d1d2; + indices[j + n + n] = besti2; + } +} + +/* + Function: interpolate three nearest neighbors (forward) + Args: + b : batch size + c : #channels of features + m : number of query centers + n : number of points in point clouds + centers_features: features of centers, FloatTensor[b, c, m] + indices : indices of nearest 3 centers to the point, + IntTensor[b, 3, n] + weights : weights for interpolation, FloatTensor[b, 3, n] + out : features of points, FloatTensor[b, c, n] +*/ +__global__ void three_nearest_neighbors_interpolate_kernel( + int b, int c, int m, int n, const float *__restrict__ centers_features, + const int *__restrict__ indices, const float *__restrict__ weights, + float *__restrict__ out) { + int batch_index = blockIdx.x; + centers_features += batch_index * m * c; + indices += batch_index * n * 3; + weights += batch_index * n * 3; + out += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weights[j]; + float w2 = weights[j + n]; + float w3 = weights[j + n + n]; + int i1 = indices[j]; + int i2 = indices[j + n]; + int i3 = indices[j + n + n]; + + out[i] = centers_features[l * m + i1] * w1 
+ + centers_features[l * m + i2] * w2 + + centers_features[l * m + i3] * w3; + } +} + +void three_nearest_neighbors_interpolate(int b, int c, int m, int n, + const float *points_coords, + const float *centers_coords, + const float *centers_features, + int *indices, float *weights, + float *out) { + three_nearest_neighbors_kernel<<>>( + b, n, m, points_coords, centers_coords, weights, indices); + three_nearest_neighbors_interpolate_kernel<<< + b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>( + b, c, m, n, centers_features, indices, weights, out); + CUDA_CHECK_ERRORS(); +} + +/* + Function: interpolate three nearest neighbors (backward) + Args: + b : batch size + c : #channels of features + m : number of query centers + n : number of points in point clouds + grad_y : grad of features of points, FloatTensor[b, c, n] + indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n] + weights : weights for interpolation, FloatTensor[b, 3, n] + grad_x : grad of features of centers, FloatTensor[b, c, m] +*/ +__global__ void three_nearest_neighbors_interpolate_grad_kernel( + int b, int c, int n, int m, const float *__restrict__ grad_y, + const int *__restrict__ indices, const float *__restrict__ weights, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + grad_y += batch_index * n * c; + indices += batch_index * n * 3; + weights += batch_index * n * 3; + grad_x += batch_index * m * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weights[j]; + float w2 = weights[j + n]; + float w3 = weights[j + n + n]; + int i1 = indices[j]; + int i2 = indices[j + n]; + int i3 = indices[j + n + n]; + atomicAdd(grad_x + l * m + i1, grad_y[i] * w1); + atomicAdd(grad_x + l * m + i2, grad_y[i] * w2); + atomicAdd(grad_x + l * m + i3, grad_y[i] * w3); + } +} + +void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m, + const float *grad_y, + const int *indices, + const float *weights, + float *grad_x) { + three_nearest_neighbors_interpolate_grad_kernel<<< + b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>( + b, c, n, m, grad_y, indices, weights, grad_x); + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cuh b/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cuh new file mode 100644 index 0000000..a15f37e --- /dev/null +++ b/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cuh @@ -0,0 +1,16 @@ +#ifndef _NEIGHBOR_INTERPOLATE_CUH +#define _NEIGHBOR_INTERPOLATE_CUH + +void three_nearest_neighbors_interpolate(int b, int c, int m, int n, + const float *points_coords, + const float *centers_coords, + const float *centers_features, + int *indices, float *weights, + float *out); +void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m, + const float *grad_y, + const int *indices, + const float *weights, + float *grad_x); + +#endif diff --git a/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.hpp b/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.hpp new file mode 100644 index 0000000..cdc7835 --- /dev/null +++ b/third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.hpp @@ -0,0 +1,16 @@ +#ifndef _NEIGHBOR_INTERPOLATE_HPP +#define _NEIGHBOR_INTERPOLATE_HPP + +#include +#include + +std::vector 
+three_nearest_neighbors_interpolate_forward(at::Tensor points_coords, + at::Tensor centers_coords, + at::Tensor centers_features); +at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y, + at::Tensor indices, + at::Tensor weights, + const int m); + +#endif diff --git a/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cpp b/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cpp new file mode 100644 index 0000000..a8ff4fc --- /dev/null +++ b/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cpp @@ -0,0 +1,91 @@ +#include "trilinear_devox.hpp" +#include "trilinear_devox.cuh" + +#include "../utils.hpp" + +/* + Function: trilinear devoxelization (forward) + Args: + r : voxel resolution + trainig : whether is training mode + coords : the coordinates of points, FloatTensor[b, 3, n] + features : features, FloatTensor[b, c, s], s = r ** 3 + Return: + outs : outputs, FloatTensor[b, c, n] + inds : the voxel coordinates of point cube, IntTensor[b, 8, n] + wgts : weight for trilinear interpolation, FloatTensor[b, 8, n] +*/ +std::vector +trilinear_devoxelize_forward(const int r, const bool is_training, + const at::Tensor coords, + const at::Tensor features) { + CHECK_CUDA(features); + CHECK_CUDA(coords); + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(coords); + CHECK_IS_FLOAT(features); + CHECK_IS_FLOAT(coords); + + int b = features.size(0); + int c = features.size(1); + int n = coords.size(2); + int r2 = r * r; + int r3 = r2 * r; + at::Tensor outs = torch::zeros( + {b, c, n}, at::device(features.device()).dtype(at::ScalarType::Float)); + if (is_training) { + at::Tensor inds = torch::zeros( + {b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Int)); + at::Tensor wgts = torch::zeros( + {b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Float)); + trilinear_devoxelize(b, c, n, r, r2, r3, true, coords.data_ptr(), + features.data_ptr(), inds.data_ptr(), + wgts.data_ptr(), outs.data_ptr()); + return {outs, inds, wgts}; + } else { + at::Tensor inds = torch::zeros( + {1}, at::device(features.device()).dtype(at::ScalarType::Int)); + at::Tensor wgts = torch::zeros( + {1}, at::device(features.device()).dtype(at::ScalarType::Float)); + trilinear_devoxelize(b, c, n, r, r2, r3, false, coords.data_ptr(), + features.data_ptr(), inds.data_ptr(), + wgts.data_ptr(), outs.data_ptr()); + return {outs, inds, wgts}; + } +} + +/* + Function: trilinear devoxelization (backward) + Args: + grad_y : grad outputs, FloatTensor[b, c, n] + indices : the voxel coordinates of point cube, IntTensor[b, 8, n] + weights : weight for trilinear interpolation, FloatTensor[b, 8, n] + r : voxel resolution + Return: + grad_x : grad inputs, FloatTensor[b, c, s], s = r ** 3 +*/ +at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y, + const at::Tensor indices, + const at::Tensor weights, + const int r) { + CHECK_CUDA(grad_y); + CHECK_CUDA(weights); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(weights); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_FLOAT(weights); + CHECK_IS_INT(indices); + + int b = grad_y.size(0); + int c = grad_y.size(1); + int n = grad_y.size(2); + int r3 = r * r * r; + at::Tensor grad_x = torch::zeros( + {b, c, r3}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + trilinear_devoxelize_grad(b, c, n, r3, indices.data_ptr(), + weights.data_ptr(), grad_y.data_ptr(), + grad_x.data_ptr()); + return grad_x; +} diff --git a/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cu 
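A note on the weights computed by three_nearest_neighbors_kernel above: d1*d2 / (d0*d1 + d0*d2 + d1*d2) equals (1/d0) / (1/d0 + 1/d1 + 1/d2), i.e. normalized inverse squared distances over the three nearest centers. The following is a batched PyTorch sketch of the same interpolation (illustrative name, without the kernel's memory frugality):

import torch

def three_nn_interpolate_reference(points_coords, centers_coords, centers_features):
    # points_coords: [B, 3, N]; centers_coords: [B, 3, M]; centers_features: [B, C, M]
    # returns interpolated point features [B, C, N]
    b, c, m = centers_features.shape
    n = points_coords.shape[2]
    d2 = torch.cdist(points_coords.transpose(1, 2),
                     centers_coords.transpose(1, 2)) ** 2      # squared distances [B, N, M]
    dist, idx = d2.topk(3, dim=2, largest=False)               # 3 nearest centers per point
    w = 1.0 / dist.clamp(min=1e-10, max=1e10)                  # same clamping as the kernel
    w = w / w.sum(dim=2, keepdim=True)                         # normalized weights [B, N, 3]
    flat = idx.reshape(b, 1, n * 3).expand(b, c, n * 3)
    gathered = torch.gather(centers_features, 2, flat).reshape(b, c, n, 3)
    return (gathered * w.unsqueeze(1)).sum(dim=3)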
b/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cu new file mode 100644 index 0000000..4e1e50c --- /dev/null +++ b/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cu @@ -0,0 +1,178 @@ +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: trilinear devoxlization (forward) + Args: + b : batch size + c : #channels + n : number of points + r : voxel resolution + r2 : r ** 2 + r3 : r ** 3 + coords : the coordinates of points, FloatTensor[b, 3, n] + feat : features, FloatTensor[b, c, r3] + inds : the voxel indices of point cube, IntTensor[b, 8, n] + wgts : weight for trilinear interpolation, FloatTensor[b, 8, n] + outs : outputs, FloatTensor[b, c, n] +*/ +__global__ void trilinear_devoxelize_kernel(int b, int c, int n, int r, int r2, + int r3, bool is_training, + const float *__restrict__ coords, + const float *__restrict__ feat, + int *__restrict__ inds, + float *__restrict__ wgts, + float *__restrict__ outs) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + coords += batch_index * n * 3; + inds += batch_index * n * 8; + wgts += batch_index * n * 8; + feat += batch_index * c * r3; + outs += batch_index * c * n; + + for (int i = index; i < n; i += stride) { + float x = coords[i]; + float y = coords[i + n]; + float z = coords[i + n + n]; + float x_lo_f = floorf(x); + float y_lo_f = floorf(y); + float z_lo_f = floorf(z); + + float x_d_1 = x - x_lo_f; // / (x_hi_f - x_lo_f + 1e-8f) + float y_d_1 = y - y_lo_f; + float z_d_1 = z - z_lo_f; + float x_d_0 = 1.0f - x_d_1; + float y_d_0 = 1.0f - y_d_1; + float z_d_0 = 1.0f - z_d_1; + + float wgt000 = x_d_0 * y_d_0 * z_d_0; + float wgt001 = x_d_0 * y_d_0 * z_d_1; + float wgt010 = x_d_0 * y_d_1 * z_d_0; + float wgt011 = x_d_0 * y_d_1 * z_d_1; + float wgt100 = x_d_1 * y_d_0 * z_d_0; + float wgt101 = x_d_1 * y_d_0 * z_d_1; + float wgt110 = x_d_1 * y_d_1 * z_d_0; + float wgt111 = x_d_1 * y_d_1 * z_d_1; + + int x_lo = static_cast(x_lo_f); + int y_lo = static_cast(y_lo_f); + int z_lo = static_cast(z_lo_f); + int x_hi = (x_d_1 > 0) ? -1 : 0; + int y_hi = (y_d_1 > 0) ? -1 : 0; + int z_hi = (z_d_1 > 0) ? 
1 : 0; + + int idx000 = x_lo * r2 + y_lo * r + z_lo; + int idx001 = idx000 + z_hi; // x_lo * r2 + y_lo * r + z_hi; + int idx010 = idx000 + (y_hi & r); // x_lo * r2 + y_hi * r + z_lo; + int idx011 = idx010 + z_hi; // x_lo * r2 + y_hi * r + z_hi; + int idx100 = idx000 + (x_hi & r2); // x_hi * r2 + y_lo * r + z_lo; + int idx101 = idx100 + z_hi; // x_hi * r2 + y_lo * r + z_hi; + int idx110 = idx100 + (y_hi & r); // x_hi * r2 + y_hi * r + z_lo; + int idx111 = idx110 + z_hi; // x_hi * r2 + y_hi * r + z_hi; + + if (is_training) { + wgts[i] = wgt000; + wgts[i + n] = wgt001; + wgts[i + n * 2] = wgt010; + wgts[i + n * 3] = wgt011; + wgts[i + n * 4] = wgt100; + wgts[i + n * 5] = wgt101; + wgts[i + n * 6] = wgt110; + wgts[i + n * 7] = wgt111; + inds[i] = idx000; + inds[i + n] = idx001; + inds[i + n * 2] = idx010; + inds[i + n * 3] = idx011; + inds[i + n * 4] = idx100; + inds[i + n * 5] = idx101; + inds[i + n * 6] = idx110; + inds[i + n * 7] = idx111; + } + + for (int j = 0; j < c; j++) { + int jr3 = j * r3; + outs[j * n + i] = + wgt000 * feat[jr3 + idx000] + wgt001 * feat[jr3 + idx001] + + wgt010 * feat[jr3 + idx010] + wgt011 * feat[jr3 + idx011] + + wgt100 * feat[jr3 + idx100] + wgt101 * feat[jr3 + idx101] + + wgt110 * feat[jr3 + idx110] + wgt111 * feat[jr3 + idx111]; + } + } +} + +/* + Function: trilinear devoxlization (backward) + Args: + b : batch size + c : #channels + n : number of points + r3 : voxel cube size = voxel resolution ** 3 + inds : the voxel indices of point cube, IntTensor[b, 8, n] + wgts : weight for trilinear interpolation, FloatTensor[b, 8, n] + grad_y : grad outputs, FloatTensor[b, c, n] + grad_x : grad inputs, FloatTensor[b, c, r3] +*/ +__global__ void trilinear_devoxelize_grad_kernel( + int b, int c, int n, int r3, const int *__restrict__ inds, + const float *__restrict__ wgts, const float *__restrict__ grad_y, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + inds += batch_index * n * 8; + wgts += batch_index * n * 8; + grad_x += batch_index * c * r3; + grad_y += batch_index * c * n; + + for (int i = index; i < n; i += stride) { + int idx000 = inds[i]; + int idx001 = inds[i + n]; + int idx010 = inds[i + n * 2]; + int idx011 = inds[i + n * 3]; + int idx100 = inds[i + n * 4]; + int idx101 = inds[i + n * 5]; + int idx110 = inds[i + n * 6]; + int idx111 = inds[i + n * 7]; + float wgt000 = wgts[i]; + float wgt001 = wgts[i + n]; + float wgt010 = wgts[i + n * 2]; + float wgt011 = wgts[i + n * 3]; + float wgt100 = wgts[i + n * 4]; + float wgt101 = wgts[i + n * 5]; + float wgt110 = wgts[i + n * 6]; + float wgt111 = wgts[i + n * 7]; + + for (int j = 0; j < c; j++) { + int jr3 = j * r3; + float g = grad_y[j * n + i]; + atomicAdd(grad_x + jr3 + idx000, wgt000 * g); + atomicAdd(grad_x + jr3 + idx001, wgt001 * g); + atomicAdd(grad_x + jr3 + idx010, wgt010 * g); + atomicAdd(grad_x + jr3 + idx011, wgt011 * g); + atomicAdd(grad_x + jr3 + idx100, wgt100 * g); + atomicAdd(grad_x + jr3 + idx101, wgt101 * g); + atomicAdd(grad_x + jr3 + idx110, wgt110 * g); + atomicAdd(grad_x + jr3 + idx111, wgt111 * g); + } + } +} + +void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3, + bool training, const float *coords, const float *feat, + int *inds, float *wgts, float *outs) { + trilinear_devoxelize_kernel<<>>( + b, c, n, r, r2, r3, training, coords, feat, inds, wgts, outs); + CUDA_CHECK_ERRORS(); +} + +void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds, + const float *wgts, const float *grad_y, + 
float *grad_x) { + trilinear_devoxelize_grad_kernel<<>>( + b, c, n, r3, inds, wgts, grad_y, grad_x); + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cuh b/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cuh new file mode 100644 index 0000000..8aadbaf --- /dev/null +++ b/third_party/pvcnn/functional/src/interpolate/trilinear_devox.cuh @@ -0,0 +1,13 @@ +#ifndef _TRILINEAR_DEVOX_CUH +#define _TRILINEAR_DEVOX_CUH + +// CUDA function declarations +void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3, + bool is_training, const float *coords, + const float *feat, int *inds, float *wgts, + float *outs); +void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds, + const float *wgts, const float *grad_y, + float *grad_x); + +#endif diff --git a/third_party/pvcnn/functional/src/interpolate/trilinear_devox.hpp b/third_party/pvcnn/functional/src/interpolate/trilinear_devox.hpp new file mode 100644 index 0000000..a9d6795 --- /dev/null +++ b/third_party/pvcnn/functional/src/interpolate/trilinear_devox.hpp @@ -0,0 +1,16 @@ +#ifndef _TRILINEAR_DEVOX_HPP +#define _TRILINEAR_DEVOX_HPP + +#include +#include + +std::vector trilinear_devoxelize_forward(const int r, + const bool is_training, + const at::Tensor coords, + const at::Tensor features); + +at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y, + const at::Tensor indices, + const at::Tensor weights, const int r); + +#endif diff --git a/third_party/pvcnn/functional/src/sampling/sampling.cpp b/third_party/pvcnn/functional/src/sampling/sampling.cpp new file mode 100644 index 0000000..9b8ca6e --- /dev/null +++ b/third_party/pvcnn/functional/src/sampling/sampling.cpp @@ -0,0 +1,58 @@ +#include "sampling.hpp" +#include "sampling.cuh" + +#include "../utils.hpp" + +at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices) { + CHECK_CUDA(features); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(features); + CHECK_IS_INT(indices); + + int b = features.size(0); + int c = features.size(1); + int n = features.size(2); + int m = indices.size(1); + at::Tensor output = torch::zeros( + {b, c, m}, at::device(features.device()).dtype(at::ScalarType::Float)); + gather_features(b, c, n, m, features.data_ptr(), + indices.data_ptr(), output.data_ptr()); + return output; +} + +at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices, + const int n) { + CHECK_CUDA(grad_y); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_INT(indices); + + int b = grad_y.size(0); + int c = grad_y.size(1); + at::Tensor grad_x = torch::zeros( + {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + gather_features_grad(b, c, n, indices.size(1), grad_y.data_ptr(), + indices.data_ptr(), grad_x.data_ptr()); + return grad_x; +} + +at::Tensor furthest_point_sampling_forward(at::Tensor coords, + const int num_samples) { + CHECK_CUDA(coords); + CHECK_CONTIGUOUS(coords); + CHECK_IS_FLOAT(coords); + + int b = coords.size(0); + int n = coords.size(2); + at::Tensor indices = torch::zeros( + {b, num_samples}, at::device(coords.device()).dtype(at::ScalarType::Int)); + at::Tensor distances = torch::full( + {b, n}, 1e38f, at::device(coords.device()).dtype(at::ScalarType::Float)); + furthest_point_sampling(b, n, num_samples, coords.data_ptr(), + distances.data_ptr(), indices.data_ptr()); + return indices; +} diff --git 
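To make the indexing trick in trilinear_devoxelize_kernel concrete: x_hi is -1 (all ones) or 0, so (x_hi & r2) adds an x-step of r*r only when the fractional part is nonzero, which also keeps points lying exactly on a voxel face in bounds. A plain PyTorch sketch of the same trilinear devoxelization, assuming coords are already in voxel units in [0, r) (illustrative name):

import torch

def trilinear_devoxelize_reference(coords, feat, r):
    # coords: FloatTensor[B, 3, N] in voxel units; feat: FloatTensor[B, C, r**3]
    # returns FloatTensor[B, C, N]
    b, c, _ = feat.shape
    n = coords.shape[2]
    lo = coords.floor()
    frac = coords - lo                            # fractional part per axis, [B, 3, N]
    lo = lo.long()
    hi = lo + (frac > 0).long()                   # offset only where frac > 0, as in the kernel
    out = torch.zeros(b, c, n, dtype=feat.dtype, device=feat.device)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                x = hi[:, 0] if dx else lo[:, 0]
                y = hi[:, 1] if dy else lo[:, 1]
                z = hi[:, 2] if dz else lo[:, 2]
                w = ((frac[:, 0] if dx else 1 - frac[:, 0]) *
                     (frac[:, 1] if dy else 1 - frac[:, 1]) *
                     (frac[:, 2] if dz else 1 - frac[:, 2]))    # corner weight [B, N]
                idx = (x * r * r + y * r + z).unsqueeze(1).expand(b, c, n)
                out += torch.gather(feat, 2, idx) * w.unsqueeze(1)
    return out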
a/third_party/pvcnn/functional/src/sampling/sampling.cu b/third_party/pvcnn/functional/src/sampling/sampling.cu new file mode 100644 index 0000000..06bc0ee --- /dev/null +++ b/third_party/pvcnn/functional/src/sampling/sampling.cu @@ -0,0 +1,174 @@ +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: gather centers' features (forward) + Args: + b : batch size + c : #channles of features + n : number of points in point clouds + m : number of query/sampled centers + features: points' features, FloatTensor[b, c, n] + indices : centers' indices in points, IntTensor[b, m] + out : gathered features, FloatTensor[b, c, m] +*/ +__global__ void gather_features_kernel(int b, int c, int n, int m, + const float *__restrict__ features, + const int *__restrict__ indices, + float *__restrict__ out) { + int batch_index = blockIdx.x; + int channel_index = blockIdx.y; + int temp_index = batch_index * c + channel_index; + features += temp_index * n; + indices += batch_index * m; + out += temp_index * m; + + for (int j = threadIdx.x; j < m; j += blockDim.x) { + out[j] = features[indices[j]]; + } +} + +void gather_features(int b, int c, int n, int m, const float *features, + const int *indices, float *out) { + gather_features_kernel<<>>( + b, c, n, m, features, indices, out); + CUDA_CHECK_ERRORS(); +} + +/* + Function: gather centers' features (backward) + Args: + b : batch size + c : #channles of features + n : number of points in point clouds + m : number of query/sampled centers + grad_y : grad of gathered features, FloatTensor[b, c, m] + indices : centers' indices in points, IntTensor[b, m] + grad_x : grad of points' features, FloatTensor[b, c, n] +*/ +__global__ void gather_features_grad_kernel(int b, int c, int n, int m, + const float *__restrict__ grad_y, + const int *__restrict__ indices, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + int channel_index = blockIdx.y; + int temp_index = batch_index * c + channel_index; + grad_y += temp_index * m; + indices += batch_index * m; + grad_x += temp_index * n; + + for (int j = threadIdx.x; j < m; j += blockDim.x) { + atomicAdd(grad_x + indices[j], grad_y[j]); + } +} + +void gather_features_grad(int b, int c, int n, int m, const float *grad_y, + const int *indices, float *grad_x) { + gather_features_grad_kernel<<>>( + b, c, n, m, grad_y, indices, grad_x); + CUDA_CHECK_ERRORS(); +} + +/* + Function: furthest point sampling + Args: + b : batch size + n : number of points in point clouds + m : number of query/sampled centers + coords : points' coords, FloatTensor[b, 3, n] + distances : minimum distance of a point to the set, IntTensor[b, n] + indices : sampled centers' indices in points, IntTensor[b, m] +*/ +__global__ void furthest_point_sampling_kernel(int b, int n, int m, + const float *__restrict__ coords, + float *__restrict__ distances, + int *__restrict__ indices) { + if (m <= 0) + return; + int batch_index = blockIdx.x; + coords += batch_index * n * 3; + distances += batch_index * n; + indices += batch_index * m; + + const int BlockSize = 512; + __shared__ float dists[BlockSize]; + __shared__ int dists_i[BlockSize]; + const int BufferSize = 3072; + __shared__ float buf[BufferSize * 3]; + + int old = 0; + if (threadIdx.x == 0) + indices[0] = old; + + for (int j = threadIdx.x; j < min(BufferSize, n); j += blockDim.x) { + buf[j] = coords[j]; + buf[j + BufferSize] = coords[j + n]; + buf[j + BufferSize + BufferSize] = coords[j + n + n]; + } + __syncthreads(); + + for (int j = 1; j < m; j++) { + int besti = 0; // best index + 
float best = -1; // farthest distance + // calculating the distance with the latest sampled point + float x1 = coords[old]; + float y1 = coords[old + n]; + float z1 = coords[old + n + n]; + for (int k = threadIdx.x; k < n; k += blockDim.x) { + // fetch distance at block n, thread k + float td = distances[k]; + float x2, y2, z2; + if (k < BufferSize) { + x2 = buf[k]; + y2 = buf[k + BufferSize]; + z2 = buf[k + BufferSize + BufferSize]; + } else { + x2 = coords[k]; + y2 = coords[k + n]; + z2 = coords[k + n + n]; + } + float d = + (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + float d2 = min(d, td); + // update "point-to-set" distance + if (d2 != td) + distances[k] = d2; + // update the farthest distance at sample step j + if (d2 > best) { + best = d2; + besti = k; + } + } + + dists[threadIdx.x] = best; + dists_i[threadIdx.x] = besti; + for (int u = 0; (1 << u) < blockDim.x; u++) { + __syncthreads(); + if (threadIdx.x < (blockDim.x >> (u + 1))) { + int i1 = (threadIdx.x * 2) << u; + int i2 = (threadIdx.x * 2 + 1) << u; + if (dists[i1] < dists[i2]) { + dists[i1] = dists[i2]; + dists_i[i1] = dists_i[i2]; + } + } + } + __syncthreads(); + + // finish sample step j; old is the sampled index + old = dists_i[0]; + if (threadIdx.x == 0) + indices[j] = old; + } +} + +void furthest_point_sampling(int b, int n, int m, const float *coords, + float *distances, int *indices) { + furthest_point_sampling_kernel<<>>(b, n, m, coords, distances, + indices); + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pvcnn/functional/src/sampling/sampling.cuh b/third_party/pvcnn/functional/src/sampling/sampling.cuh new file mode 100644 index 0000000..e68358f --- /dev/null +++ b/third_party/pvcnn/functional/src/sampling/sampling.cuh @@ -0,0 +1,11 @@ +#ifndef _SAMPLING_CUH +#define _SAMPLING_CUH + +void gather_features(int b, int c, int n, int m, const float *features, + const int *indices, float *out); +void gather_features_grad(int b, int c, int n, int m, const float *grad_y, + const int *indices, float *grad_x); +void furthest_point_sampling(int b, int n, int m, const float *coords, + float *distances, int *indices); + +#endif diff --git a/third_party/pvcnn/functional/src/sampling/sampling.hpp b/third_party/pvcnn/functional/src/sampling/sampling.hpp new file mode 100644 index 0000000..db2a5c8 --- /dev/null +++ b/third_party/pvcnn/functional/src/sampling/sampling.hpp @@ -0,0 +1,12 @@ +#ifndef _SAMPLING_HPP +#define _SAMPLING_HPP + +#include + +at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices); +at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices, + const int n); +at::Tensor furthest_point_sampling_forward(at::Tensor coords, + const int num_samples); + +#endif diff --git a/third_party/pvcnn/functional/src/utils.hpp b/third_party/pvcnn/functional/src/utils.hpp new file mode 100644 index 0000000..f4f21a0 --- /dev/null +++ b/third_party/pvcnn/functional/src/utils.hpp @@ -0,0 +1,20 @@ +#ifndef _UTILS_HPP +#define _UTILS_HPP + +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") + +#define CHECK_IS_INT(x) \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ + #x " must be an int tensor") + +#define CHECK_IS_FLOAT(x) \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ + #x " must be a float tensor") + +#endif diff --git a/third_party/pvcnn/functional/src/voxelization/vox.cpp 
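furthest_point_sampling_kernel above implements greedy farthest-point sampling: start from point 0, keep each point's minimum squared distance to the selected set (seeded at 1e38 by the caller), and repeatedly pick the argmax, with shared memory and a parallel max-reduction doing the heavy lifting. A batched PyTorch sketch of the same greedy loop (illustrative name):

import torch

def furthest_point_sampling_reference(coords, num_samples):
    # coords: FloatTensor[B, 3, N]; returns LongTensor[B, num_samples] of indices
    b, _, n = coords.shape
    xyz = coords.transpose(1, 2)                                   # [B, N, 3]
    indices = torch.zeros(b, num_samples, dtype=torch.long, device=coords.device)
    dists = torch.full((b, n), 1e38, device=coords.device)         # min sq. dist to the set
    farthest = torch.zeros(b, dtype=torch.long, device=coords.device)  # start at index 0
    batch = torch.arange(b, device=coords.device)
    for j in range(num_samples):
        indices[:, j] = farthest
        latest = xyz[batch, farthest].unsqueeze(1)                 # newest center [B, 1, 3]
        d = ((xyz - latest) ** 2).sum(-1)
        dists = torch.minimum(dists, d)                            # update point-to-set distance
        farthest = dists.argmax(dim=1)                             # next center = farthest point
    return indices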
b/third_party/pvcnn/functional/src/voxelization/vox.cpp new file mode 100644 index 0000000..6a84594 --- /dev/null +++ b/third_party/pvcnn/functional/src/voxelization/vox.cpp @@ -0,0 +1,76 @@ +#include "vox.hpp" +#include "vox.cuh" + +#include "../utils.hpp" + +/* + Function: average pool voxelization (forward) + Args: + features: features, FloatTensor[b, c, n] + coords : coords of each point, IntTensor[b, 3, n] + resolution : voxel resolution + Return: + out : outputs, FloatTensor[b, c, s], s = r ** 3 + ind : voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] +*/ +std::vector avg_voxelize_forward(const at::Tensor features, + const at::Tensor coords, + const int resolution) { + CHECK_CUDA(features); + CHECK_CUDA(coords); + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(coords); + CHECK_IS_FLOAT(features); + CHECK_IS_INT(coords); + + int b = features.size(0); + int c = features.size(1); + int n = features.size(2); + int r = resolution; + int r2 = r * r; + int r3 = r2 * r; + at::Tensor ind = torch::zeros( + {b, n}, at::device(features.device()).dtype(at::ScalarType::Int)); + at::Tensor out = torch::zeros( + {b, c, r3}, at::device(features.device()).dtype(at::ScalarType::Float)); + at::Tensor cnt = torch::zeros( + {b, r3}, at::device(features.device()).dtype(at::ScalarType::Int)); + avg_voxelize(b, c, n, r, r2, r3, coords.data_ptr(), + features.data_ptr(), ind.data_ptr(), + cnt.data_ptr(), out.data_ptr()); + return {out, ind, cnt}; +} + +/* + Function: average pool voxelization (backward) + Args: + grad_y : grad outputs, FloatTensor[b, c, s] + indices: voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] + Return: + grad_x : grad inputs, FloatTensor[b, c, n] +*/ +at::Tensor avg_voxelize_backward(const at::Tensor grad_y, + const at::Tensor indices, + const at::Tensor cnt) { + CHECK_CUDA(grad_y); + CHECK_CUDA(indices); + CHECK_CUDA(cnt); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(indices); + CHECK_CONTIGUOUS(cnt); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_INT(indices); + CHECK_IS_INT(cnt); + + int b = grad_y.size(0); + int c = grad_y.size(1); + int s = grad_y.size(2); + int n = indices.size(1); + at::Tensor grad_x = torch::zeros( + {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + avg_voxelize_grad(b, c, n, s, indices.data_ptr(), cnt.data_ptr(), + grad_y.data_ptr(), grad_x.data_ptr()); + return grad_x; +} diff --git a/third_party/pvcnn/functional/src/voxelization/vox.cu b/third_party/pvcnn/functional/src/voxelization/vox.cu new file mode 100644 index 0000000..1c1a2c9 --- /dev/null +++ b/third_party/pvcnn/functional/src/voxelization/vox.cu @@ -0,0 +1,126 @@ +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: get how many points in each voxel grid + Args: + b : batch size + n : number of points + r : voxel resolution + r2 : = r * r + r3 : s, voxel cube size = r ** 3 + coords : coords of each point, IntTensor[b, 3, n] + ind : voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] +*/ +__global__ void grid_stats_kernel(int b, int n, int r, int r2, int r3, + const int *__restrict__ coords, + int *__restrict__ ind, int *cnt) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + coords += batch_index * n * 3; + ind += batch_index * n; + cnt += batch_index * r3; + + for (int i = index; i < n; i += stride) { + // if (ind[i] == -1) + // continue; + ind[i] = coords[i] * r2 + coords[i + n] * r + coords[i + 
n + n]; + atomicAdd(cnt + ind[i], 1); + } +} + +/* + Function: average pool voxelization (forward) + Args: + b : batch size + c : #channels + n : number of points + s : voxel cube size = voxel resolution ** 3 + ind : voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] + feat: features, FloatTensor[b, c, n] + out : outputs, FloatTensor[b, c, s] +*/ +__global__ void avg_voxelize_kernel(int b, int c, int n, int s, + const int *__restrict__ ind, + const int *__restrict__ cnt, + const float *__restrict__ feat, + float *__restrict__ out) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + ind += batch_index * n; + feat += batch_index * c * n; + out += batch_index * c * s; + cnt += batch_index * s; + for (int i = index; i < n; i += stride) { + int pos = ind[i]; + // if (pos == -1) + // continue; + int cur_cnt = cnt[pos]; + if (cur_cnt > 0) { + float div_cur_cnt = 1.0 / static_cast(cur_cnt); + for (int j = 0; j < c; j++) { + atomicAdd(out + j * s + pos, feat[j * n + i] * div_cur_cnt); + } + } + } +} + +/* + Function: average pool voxelization (backward) + Args: + b : batch size + c : #channels + n : number of points + r3 : voxel cube size = voxel resolution ** 3 + ind : voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] + grad_y : grad outputs, FloatTensor[b, c, s] + grad_x : grad inputs, FloatTensor[b, c, n] +*/ +__global__ void avg_voxelize_grad_kernel(int b, int c, int n, int r3, + const int *__restrict__ ind, + const int *__restrict__ cnt, + const float *__restrict__ grad_y, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + ind += batch_index * n; + grad_x += batch_index * c * n; + grad_y += batch_index * c * r3; + cnt += batch_index * r3; + for (int i = index; i < n; i += stride) { + int pos = ind[i]; + // if (pos == -1) + // continue; + int cur_cnt = cnt[pos]; + if (cur_cnt > 0) { + float div_cur_cnt = 1.0 / static_cast(cur_cnt); + for (int j = 0; j < c; j++) { + atomicAdd(grad_x + j * n + i, grad_y[j * r3 + pos] * div_cur_cnt); + } + } + } +} + +void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords, + const float *feat, int *ind, int *cnt, float *out) { + grid_stats_kernel<<>>(b, n, r, r2, r3, coords, ind, + cnt); + avg_voxelize_kernel<<>>(b, c, n, r3, ind, cnt, + feat, out); + CUDA_CHECK_ERRORS(); +} + +void avg_voxelize_grad(int b, int c, int n, int s, const int *ind, + const int *cnt, const float *grad_y, float *grad_x) { + avg_voxelize_grad_kernel<<>>(b, c, n, s, ind, cnt, + grad_y, grad_x); + CUDA_CHECK_ERRORS(); +} diff --git a/third_party/pvcnn/functional/src/voxelization/vox.cuh b/third_party/pvcnn/functional/src/voxelization/vox.cuh new file mode 100644 index 0000000..9adb0fd --- /dev/null +++ b/third_party/pvcnn/functional/src/voxelization/vox.cuh @@ -0,0 +1,10 @@ +#ifndef _VOX_CUH +#define _VOX_CUH + +// CUDA function declarations +void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords, + const float *feat, int *ind, int *cnt, float *out); +void avg_voxelize_grad(int b, int c, int n, int s, const int *idx, + const int *cnt, const float *grad_y, float *grad_x); + +#endif diff --git a/third_party/pvcnn/functional/src/voxelization/vox.hpp b/third_party/pvcnn/functional/src/voxelization/vox.hpp new file mode 100644 index 0000000..6e62bc3 --- /dev/null +++ b/third_party/pvcnn/functional/src/voxelization/vox.hpp @@ -0,0 +1,15 @@ +#ifndef 
_VOX_HPP +#define _VOX_HPP + +#include +#include + +std::vector avg_voxelize_forward(const at::Tensor features, + const at::Tensor coords, + const int resolution); + +at::Tensor avg_voxelize_backward(const at::Tensor grad_y, + const at::Tensor indices, + const at::Tensor cnt); + +#endif diff --git a/third_party/pvcnn/functional/voxelization.py b/third_party/pvcnn/functional/voxelization.py new file mode 100644 index 0000000..53c86fa --- /dev/null +++ b/third_party/pvcnn/functional/voxelization.py @@ -0,0 +1,47 @@ +from torch.autograd import Function +import torch +# from modules.functional.backend import _backend +from third_party.pvcnn.functional.backend import _backend +from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd + +__all__ = ['avg_voxelize'] + + +class AvgVoxelization(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, features, coords, resolution): + """ + :param ctx: + :param features: Features of the point cloud, FloatTensor[B, C, N] + :param coords: Voxelized Coordinates of each point, IntTensor[B, 3, N] + :param resolution: Voxel resolution + :return: + Voxelized Features, FloatTensor[B, C, R, R, R] + """ + features = features.contiguous() + coords = coords.int()[:,:3].contiguous() + b, c, _ = features.shape + out, indices, counts = _backend.avg_voxelize_forward( + features, coords, resolution) + ctx.save_for_backward(indices, counts) + return out.view(b, c, resolution, resolution, resolution) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + """ + :param ctx: + :param grad_output: gradient of output, FloatTensor[B, C, R, R, R] + :return: + gradient of inputs, FloatTensor[B, C, N] + """ + b, c = grad_output.shape[:2] + indices, counts = ctx.saved_tensors + grad_features = _backend.avg_voxelize_backward( + grad_output.contiguous().view(b, c, -1), indices, counts) + return grad_features, None, None + + +avg_voxelize = AvgVoxelization.apply + diff --git a/third_party/torchdiffeq/LICENSE b/third_party/torchdiffeq/LICENSE new file mode 100644 index 0000000..bcd034d --- /dev/null +++ b/third_party/torchdiffeq/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Ricky Tian Qi Chen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
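For reference, avg_voxelize_forward above maps each point's integer coordinates to the flat voxel index x*r*r + y*r + z and averages the features of all points landing in the same voxel; scatter_add_ expresses this directly in PyTorch. A minimal sketch with an illustrative name (the real extension also returns ind/cnt for the backward pass, and AvgVoxelization.forward reshapes the output to [B, C, R, R, R]):

import torch

def avg_voxelize_reference(features, coords, r):
    # features: FloatTensor[B, C, N]; coords: IntTensor[B, 3, N], entries in [0, r)
    # returns (out [B, C, r**3], ind [B, N], cnt [B, r**3])
    b, c, n = features.shape
    ind = (coords[:, 0] * r * r + coords[:, 1] * r + coords[:, 2]).long()   # flat voxel index
    cnt = torch.zeros(b, r ** 3, device=features.device)
    cnt.scatter_add_(1, ind, torch.ones(b, n, device=features.device))      # points per voxel
    out = torch.zeros(b, c, r ** 3, device=features.device)
    out.scatter_add_(2, ind.unsqueeze(1).expand(b, c, n), features)         # sum per voxel
    out = out / cnt.clamp(min=1).unsqueeze(1)                               # average, empty-safe
    return out, ind.int(), cnt.int()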
diff --git a/third_party/torchdiffeq/README.md b/third_party/torchdiffeq/README.md new file mode 100644 index 0000000..ad97ddc --- /dev/null +++ b/third_party/torchdiffeq/README.md @@ -0,0 +1 @@ +adapted from `https://github.com/rtqichen/torchdiffeq/tree/master/torchdiffeq` diff --git a/third_party/torchdiffeq/torchdiffeq/__init__.py b/third_party/torchdiffeq/torchdiffeq/__init__.py new file mode 100644 index 0000000..bf4f651 --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/__init__.py @@ -0,0 +1,4 @@ +from ._impl import odeint +from ._impl import odeint_adjoint +from ._impl import odeint_event +__version__ = "0.2.2" diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/__init__.py b/third_party/torchdiffeq/torchdiffeq/_impl/__init__.py new file mode 100644 index 0000000..05b671e --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/__init__.py @@ -0,0 +1,2 @@ +from .odeint import odeint, odeint_event +from .adjoint import odeint_adjoint diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/adaptive_heun.py b/third_party/torchdiffeq/torchdiffeq/_impl/adaptive_heun.py new file mode 100644 index 0000000..10f9696 --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/adaptive_heun.py @@ -0,0 +1,25 @@ +import torch +from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver + + +_ADAPTIVE_HEUN_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1.], dtype=torch.float64), + beta=[ + torch.tensor([1.], dtype=torch.float64), + ], + c_sol=torch.tensor([0.5, 0.5], dtype=torch.float64), + c_error=torch.tensor([ + 0.5, + -0.5, + ], dtype=torch.float64), +) + +_AH_C_MID = torch.tensor([ + 0.5, 0. +], dtype=torch.float64) + + +class AdaptiveHeunSolver(RKAdaptiveStepsizeODESolver): + order = 2 + tableau = _ADAPTIVE_HEUN_TABLEAU + mid = _AH_C_MID diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/adjoint.py b/third_party/torchdiffeq/torchdiffeq/_impl/adjoint.py new file mode 100644 index 0000000..aa7c60e --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/adjoint.py @@ -0,0 +1,280 @@ +import warnings +import torch +import torch.nn as nn +from .odeint import SOLVERS, odeint +from .misc import _check_inputs, _flat_to_shape +from .misc import _mixed_norm + + +class OdeintAdjointMethod(torch.autograd.Function): + + @staticmethod + def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, + adjoint_options, t_requires_grad, *adjoint_params): + + ctx.shapes = shapes + ctx.func = func + ctx.adjoint_rtol = adjoint_rtol + ctx.adjoint_atol = adjoint_atol + ctx.adjoint_method = adjoint_method + ctx.adjoint_options = adjoint_options + ctx.t_requires_grad = t_requires_grad + ctx.event_mode = event_fn is not None + + with torch.no_grad(): + ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn) + + if event_fn is None: + y = ans + else: + event_t, y = ans + ctx.event_t = event_t + + ctx.save_for_backward(t, y, *adjoint_params) + return ans + + @staticmethod + def backward(ctx, *grad_y): + with torch.no_grad(): + func = ctx.func + adjoint_rtol = ctx.adjoint_rtol + adjoint_atol = ctx.adjoint_atol + adjoint_method = ctx.adjoint_method + adjoint_options = ctx.adjoint_options + t_requires_grad = ctx.t_requires_grad + + t, y, *adjoint_params = ctx.saved_tensors + adjoint_params = tuple(adjoint_params) + + # Backprop as if integrating up to event time. + # Does NOT backpropagate through the event time. 
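+            # What follows is the continuous adjoint method of Chen et al. (2018):
+            # rather than storing every intermediate solver step, backward()
+            # re-integrates an augmented ODE in reverse time whose state is
+            #     (vjp_t, y, vjp_y, *vjp_params),
+            # where vjp_y is the adjoint a(t) = dL/dy(t) and the trailing entries
+            # accumulate -a(t)^T df/dparams, giving O(1) memory in solver steps.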
+ event_mode = ctx.event_mode + if event_mode: + event_t = ctx.event_t + _t = t + t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)]) + grad_y = grad_y[1] + else: + grad_y = grad_y[0] + + ################################## + # Set up initial state # + ################################## + + # [-1] because y and grad_y are both of shape (len(t), *y0.shape) + aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]] # vjp_t, y, vjp_y + aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params + + ################################## + # Set up backward ODE func # + ################################## + + # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives. + def augmented_dynamics(t, y_aug): + # Dynamics of the original system augmented with + # the adjoint wrt y, and an integrator wrt t and args. + y = y_aug[1] + adj_y = y_aug[2] + # ignore gradients wrt time and parameters + + with torch.enable_grad(): + t_ = t.detach() + t = t_.requires_grad_(True) + y = y.detach().requires_grad_(True) + + # If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which + # doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients + # wrt t here means we won't compute that if we don't need it. + func_eval = func(t if t_requires_grad else t_, y) + + # Workaround for PyTorch bug #39784 + _t = torch.as_strided(t, (), ()) # noqa + _y = torch.as_strided(y, (), ()) # noqa + _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) # noqa + + vjp_t, vjp_y, *vjp_params = torch.autograd.grad( + func_eval, (t, y) + adjoint_params, -adj_y, + allow_unused=True, retain_graph=True + ) + + # autograd.grad returns None if no gradient, set to zero. + vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t + vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y + vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param + for param, vjp_param in zip(adjoint_params, vjp_params)] + + return (vjp_t, func_eval, vjp_y, *vjp_params) + + ################################## + # Solve adjoint ODE # + ################################## + + if t_requires_grad: + time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device) + else: + time_vjps = None + for i in range(len(t) - 1, 0, -1): + if t_requires_grad: + # Compute the effect of moving the current time measurement point. + # We don't compute this unless we need to, to save some computation. + func_eval = func(t[i], y[i]) + dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1)) + aug_state[0] -= dLd_cur_t + time_vjps[i] = dLd_cur_t + + # Run the augmented system backwards in time. + aug_state = odeint( + augmented_dynamics, tuple(aug_state), + t[i - 1:i + 1].flip(0), + rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options + ) + aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value + aug_state[1] = y[i - 1] # update to use our forward-pass estimate of the state + aug_state[2] += grad_y[i - 1] # update any gradients wrt state at this time point + + if t_requires_grad: + time_vjps[0] = aug_state[0] + + # Only compute gradient wrt initial time when in event handling mode. 
+ if event_mode and t_requires_grad: + time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])]) + + adj_y = aug_state[2] + adj_params = aug_state[3:] + + return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params) + + +def odeint_adjoint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None, + adjoint_rtol=None, adjoint_atol=None, adjoint_method=None, adjoint_options=None, adjoint_params=None): + + # We need this in order to access the variables inside this module, + # since we have no other way of getting variables along the execution path. + if adjoint_params is None and not isinstance(func, nn.Module): + raise ValueError('func must be an instance of nn.Module to specify the adjoint parameters; alternatively they ' + 'can be specified explicitly via the `adjoint_params` argument. If there are no parameters ' + 'then it is allowable to set `adjoint_params=()`.') + + # Must come before _check_inputs as we don't want to use normalised input (in particular any changes to options) + if adjoint_rtol is None: + adjoint_rtol = rtol + if adjoint_atol is None: + adjoint_atol = atol + if adjoint_method is None: + adjoint_method = method + + if adjoint_method != method and options is not None and adjoint_options is None: + raise ValueError("If `adjoint_method != method` then we cannot infer `adjoint_options` from `options`. So as " + "`options` has been passed then `adjoint_options` must be passed as well.") + + if adjoint_options is None: + adjoint_options = {k: v for k, v in options.items() if k != "norm"} if options is not None else {} + else: + # Avoid in-place modifying a user-specified dict. + adjoint_options = adjoint_options.copy() + + if adjoint_params is None: + adjoint_params = tuple(find_parameters(func)) + else: + adjoint_params = tuple(adjoint_params) # in case adjoint_params is a generator. + + # Filter params that don't require gradients. + oldlen_ = len(adjoint_params) + adjoint_params = tuple(p for p in adjoint_params if p.requires_grad) + if len(adjoint_params) != oldlen_: + # Some params were excluded. + # Issue a warning if a user-specified norm is specified. + if 'norm' in adjoint_options and callable(adjoint_options['norm']): + warnings.warn("An adjoint parameter was passed without requiring gradient. For efficiency this will be " + "excluded from the adjoint pass, and will not appear as a tensor in the adjoint norm.") + + # Convert to flattened state. + shapes, func, y0, t, rtol, atol, method, options, event_fn, decreasing_time = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS) + + # Handle the adjoint norm function. + state_norm = options["norm"] + handle_adjoint_norm_(adjoint_options, shapes, state_norm) + + ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, + adjoint_method, adjoint_options, t.requires_grad, *adjoint_params) + + if event_fn is None: + solution = ans + else: + event_t, solution = ans + event_t = event_t.to(t) + if decreasing_time: + event_t = -event_t + + if shapes is not None: + solution = _flat_to_shape(solution, (len(t),), shapes) + + if event_fn is None: + return solution + else: + return event_t, solution + + +def find_parameters(module): + + assert isinstance(module, nn.Module) + + # If called within DataParallel, parameters won't appear in module.parameters(). 
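# A minimal usage sketch for odeint_adjoint as defined above (the model, shapes,
# and tolerances are made up for illustration; 'seminorm' is the adjoint-norm
# option handled by handle_adjoint_norm_ below):

import torch
import torch.nn as nn
from third_party.torchdiffeq.torchdiffeq import odeint_adjoint

class ODEFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))

    def forward(self, t, y):
        return self.net(y)

func = ODEFunc()                    # nn.Module, so adjoint params are found automatically
y0 = torch.randn(8, 2)
t = torch.linspace(0., 1., 10)
ys = odeint_adjoint(func, y0, t, rtol=1e-5, atol=1e-7, method='dopri5')
ys[-1].pow(2).sum().backward()      # gradients flow through the adjoint ODE
# Cheaper adjoint norm that ignores the parameter state:
# ys = odeint_adjoint(func, y0, t, adjoint_options=dict(norm='seminorm'))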
+ if getattr(module, '_is_replica', False): + + def find_tensor_attributes(module): + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) and v.requires_grad] + return tuples + + gen = module._named_members(get_members_fn=find_tensor_attributes) + return [param for _, param in gen] + else: + return list(module.parameters()) + + +def handle_adjoint_norm_(adjoint_options, shapes, state_norm): + """In-place modifies the adjoint options to choose or wrap the norm function.""" + + # This is the default adjoint norm on the backward pass: a mixed norm over the tuple of inputs. + def default_adjoint_norm(tensor_tuple): + t, y, adj_y, *adj_params = tensor_tuple + # (If the state is actually a flattened tuple then this will be unpacked again in state_norm.) + return max(t.abs(), state_norm(y), state_norm(adj_y), _mixed_norm(adj_params)) + + if "norm" not in adjoint_options: + # `adjoint_options` was not explicitly specified by the user. Use the default norm. + adjoint_options["norm"] = default_adjoint_norm + else: + # `adjoint_options` was explicitly specified by the user... + try: + adjoint_norm = adjoint_options['norm'] + except KeyError: + # ...but they did not specify the norm argument. Back to plan A: use the default norm. + adjoint_options['norm'] = default_adjoint_norm + else: + # ...and they did specify the norm argument. + if adjoint_norm == 'seminorm': + # They told us they want to use seminorms. Slight modification to plan A: use the default norm, + # but ignore the parameter state + def adjoint_seminorm(tensor_tuple): + t, y, adj_y, *adj_params = tensor_tuple + # (If the state is actually a flattened tuple then this will be unpacked again in state_norm.) + return max(t.abs(), state_norm(y), state_norm(adj_y)) + adjoint_options['norm'] = adjoint_seminorm + else: + # And they're using their own custom norm. + if shapes is None: + # The state on the forward pass was a tensor, not a tuple. We don't need to do anything, they're + # already going to get given the full adjoint state as (t, y, adj_y, adj_params) + pass # this branch included for clarity + else: + # This is the bit that is tuple/tensor abstraction-breaking, because the odeint machinery + # doesn't know about the tupled nature of the forward state. We need to tell the user's adjoint + # norm about that ourselves. 
+ + def _adjoint_norm(tensor_tuple): + t, y, adj_y, *adj_params = tensor_tuple + y = _flat_to_shape(y, (), shapes) + adj_y = _flat_to_shape(adj_y, (), shapes) + return adjoint_norm((t, *y, *adj_y, *adj_params)) + adjoint_options['norm'] = _adjoint_norm diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/bosh3.py b/third_party/torchdiffeq/torchdiffeq/_impl/bosh3.py new file mode 100644 index 0000000..cf09e7f --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/bosh3.py @@ -0,0 +1,22 @@ +import torch +from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver + + +_BOGACKI_SHAMPINE_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1 / 2, 3 / 4, 1.], dtype=torch.float64), + beta=[ + torch.tensor([1 / 2], dtype=torch.float64), + torch.tensor([0., 3 / 4], dtype=torch.float64), + torch.tensor([2 / 9, 1 / 3, 4 / 9], dtype=torch.float64) + ], + c_sol=torch.tensor([2 / 9, 1 / 3, 4 / 9, 0.], dtype=torch.float64), + c_error=torch.tensor([2 / 9 - 7 / 24, 1 / 3 - 1 / 4, 4 / 9 - 1 / 3, -1 / 8], dtype=torch.float64), +) + +_BS_C_MID = torch.tensor([0., 0.5, 0., 0.], dtype=torch.float64) + + +class Bosh3Solver(RKAdaptiveStepsizeODESolver): + order = 3 + tableau = _BOGACKI_SHAMPINE_TABLEAU + mid = _BS_C_MID diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/dopri5.py b/third_party/torchdiffeq/torchdiffeq/_impl/dopri5.py new file mode 100644 index 0000000..1a925ef --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/dopri5.py @@ -0,0 +1,36 @@ +import torch +from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver + + +_DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64), + beta=[ + torch.tensor([1 / 5], dtype=torch.float64), + torch.tensor([3 / 40, 9 / 40], dtype=torch.float64), + torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64), + torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64), + torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64), + torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64), + ], + c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64), + c_error=torch.tensor([ + 35 / 384 - 1951 / 21600, + 0, + 500 / 1113 - 22642 / 50085, + 125 / 192 - 451 / 720, + -2187 / 6784 - -12231 / 42400, + 11 / 84 - 649 / 6300, + -1. 
/ 60., + ], dtype=torch.float64), +) + +DPS_C_MID = torch.tensor([ + 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, + 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 +], dtype=torch.float64) + + +class Dopri5Solver(RKAdaptiveStepsizeODESolver): + order = 5 + tableau = _DORMAND_PRINCE_SHAMPINE_TABLEAU + mid = DPS_C_MID diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/dopri8.py b/third_party/torchdiffeq/torchdiffeq/_impl/dopri8.py new file mode 100644 index 0000000..ef10dc4 --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/dopri8.py @@ -0,0 +1,76 @@ +import torch +from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver + + +A = [1 / 18, 1 / 12, 1 / 8, 5 / 16, 3 / 8, 59 / 400, 93 / 200, 5490023248 / 9719169821, 13 / 20, 1201146811 / 1299019798, 1, 1, 1] + +B = [ + [1 / 18], + + [1 / 48, 1 / 16], + + [1 / 32, 0, 3 / 32], + + [5 / 16, 0, -75 / 64, 75 / 64], + + [3 / 80, 0, 0, 3 / 16, 3 / 20], + + [29443841 / 614563906, 0, 0, 77736538 / 692538347, -28693883 / 1125000000, 23124283 / 1800000000], + + [16016141 / 946692911, 0, 0, 61564180 / 158732637, 22789713 / 633445777, 545815736 / 2771057229, -180193667 / 1043307555], + + [39632708 / 573591083, 0, 0, -433636366 / 683701615, -421739975 / 2616292301, 100302831 / 723423059, 790204164 / 839813087, 800635310 / 3783071287], + + [246121993 / 1340847787, 0, 0, -37695042795 / 15268766246, -309121744 / 1061227803, -12992083 / 490766935, 6005943493 / 2108947869, 393006217 / 1396673457, 123872331 / 1001029789], + + [-1028468189 / 846180014, 0, 0, 8478235783 / 508512852, 1311729495 / 1432422823, -10304129995 / 1701304382, -48777925059 / 3047939560, 15336726248 / 1032824649, -45442868181 / 3398467696, 3065993473 / 597172653], + + [185892177 / 718116043, 0, 0, -3185094517 / 667107341, -477755414 / 1098053517, -703635378 / 230739211, 5731566787 / 1027545527, 5232866602 / 850066563, -4093664535 / 808688257, 3962137247 / 1805957418, 65686358 / 487910083], + + [403863854 / 491063109, 0, 0, -5068492393 / 434740067, -411421997 / 543043805, 652783627 / 914296604, 11173962825 / 925320556, -13158990841 / 6184727034, 3936647629 / 1978049680, -160528059 / 685178525, 248638103 / 1413531060, 0], + + [14005451 / 335480064, 0, 0, 0, 0, -59238493 / 1068277825, 181606767 / 758867731, 561292985 / 797845732, -1041891430 / 1371343529, 760417239 / 1151165299, 118820643 / 751138087, -528747749 / 2220607170, 1 / 4] +] + +C_sol = [14005451 / 335480064, 0, 0, 0, 0, -59238493 / 1068277825, 181606767 / 758867731, 561292985 / 797845732, -1041891430 / 1371343529, 760417239 / 1151165299, 118820643 / 751138087, -528747749 / 2220607170, 1 / 4, 0] + +C_err = [14005451 / 335480064 - 13451932 / 455176623, 0, 0, 0, 0, -59238493 / 1068277825 - -808719846 / 976000145, 181606767 / 758867731 - 1757004468 / 5645159321, 561292985 / 797845732 - 656045339 / 265891186, -1041891430 / 1371343529 - -3867574721 / 1518517206, 760417239 / 1151165299 - 465885868 / 322736535, 118820643 / 751138087 - 53011238 / 667516719, -528747749 / 2220607170 - 2 / 45, 1 / 4, 0] + +h = 1 / 2 + +C_mid = [0.] 
* 14 + +C_mid[0] = (- 6.3448349392860401388 * (h**5) + 22.1396504998094068976 * (h**4) - 30.0610568289666450593 * (h**3) + 19.9990069333683970610 * (h**2) - 6.6910181737837595697 * h + 1.0) / (1 / h) + +C_mid[5] = (- 39.6107919852202505218 * (h**5) + 116.4422149550342161651 * (h**4) - 121.4999627731334642623 * (h**3) + 52.2273532792945524050 * (h**2) - 7.6142658045872677172 * h) / (1 / h) + +C_mid[6] = (20.3761213808791436958 * (h**5) - 67.1451318825957197185 * (h**4) + 83.1721004639847717481 * (h**3) - 46.8919164181093621583 * (h**2) + 10.7281392630428866124 * h) / (1 / h) + +C_mid[7] = (7.3347098826795362023 * (h**5) - 16.5672243527496524646 * (h**4) + 9.5724507555993664382 * (h**3) - 0.1890893225010595467 * (h**2) + 0.5526637063753648783 * h) / (1 / h) + +C_mid[8] = (32.8801774352459155182 * (h**5) - 89.9916014847245016028 * (h**4) + 87.8406057677205645007 * (h**3) - 35.7075975946222072821 * (h**2) + 4.2186562625665153803 * h) / (1 / h) + +C_mid[9] = (- 10.1588990526426760954 * (h**5) + 22.6237489648532849093 * (h**4) - 17.4152107770762969005 * (h**3) + 6.2736448083240352160 * (h**2) - 0.6627209125361597559 * h) / (1 / h) + +C_mid[10] = (- 12.5401268098782561200 * (h**5) + 32.2362340167355370113 * (h**4) - 28.5903289514790976966 * (h**3) + 10.3160881272450748458 * (h**2) - 1.2636789001135462218 * h) / (1 / h) + +C_mid[11] = (29.5553001484516038033 * (h**5) - 82.1020315488359848644 * (h**4) + 81.6630950584341412934 * (h**3) - 34.7650769866611817349 * (h**2) + 5.4106037898590422230 * h) / (1 / h) + +C_mid[12] = (- 41.7923486424390588923 * (h**5) + 116.2662185791119533462 * (h**4) - 114.9375291377009418170 * (h**3) + 47.7457971078225540396 * (h**2) - 7.0321379067945741781 * h) / (1 / h) + +C_mid[13] = (20.3006925822100825485 * (h**5) - 53.9020777466385396792 * (h**4) + 50.2558364226176017553 * (h**3) - 19.0082099341608028453 * (h**2) + 2.3537586759714983486 * h) / (1 / h) + + +A = torch.tensor(A, dtype=torch.float64) +B = [torch.tensor(B_, dtype=torch.float64) for B_ in B] +C_sol = torch.tensor(C_sol, dtype=torch.float64) +C_err = torch.tensor(C_err, dtype=torch.float64) +_C_mid = torch.tensor(C_mid, dtype=torch.float64) + +_DOPRI8_TABLEAU = _ButcherTableau(alpha=A, beta=B, c_sol=C_sol, c_error=C_err) + + +class Dopri8Solver(RKAdaptiveStepsizeODESolver): + order = 8 + tableau = _DOPRI8_TABLEAU + mid = _C_mid diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/event_handling.py b/third_party/torchdiffeq/torchdiffeq/_impl/event_handling.py new file mode 100644 index 0000000..067aac3 --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/event_handling.py @@ -0,0 +1,35 @@ +import math +import torch + + +def find_event(interp_fn, sign0, t0, t1, event_fn, tol): + with torch.no_grad(): + + # Num iterations for the secant method until tolerance is within target. + nitrs = torch.ceil(torch.log((t1 - t0) / tol) / math.log(2.0)) + + for _ in range(nitrs.long()): + t_mid = (t1 + t0) / 2.0 + y_mid = interp_fn(t_mid) + sign_mid = torch.sign(event_fn(t_mid, y_mid)) + same_as_sign0 = (sign0 == sign_mid) + t0 = torch.where(same_as_sign0, t_mid, t0) + t1 = torch.where(same_as_sign0, t1, t_mid) + event_t = (t0 + t1) / 2.0 + + return event_t, interp_fn(event_t) + + +def combine_event_functions(event_fn, t0, y0): + """ + We ensure all event functions are initially positive, + so then we can combine them by taking a min. 
+ """ + with torch.no_grad(): + initial_signs = torch.sign(event_fn(t0, y0)) + + def combined_event_fn(t, y): + c = event_fn(t, y) + return torch.min(c * initial_signs) + + return combined_event_fn diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/fehlberg2.py b/third_party/torchdiffeq/torchdiffeq/_impl/fehlberg2.py new file mode 100644 index 0000000..c6339a6 --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/fehlberg2.py @@ -0,0 +1,22 @@ +import torch +from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver + +_FEHLBERG2_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1 / 2, 1.0], dtype=torch.float64), + beta=[ + torch.tensor([1 / 2], dtype=torch.float64), + torch.tensor([1 / 256, 255 / 256], dtype=torch.float64), + ], + c_sol=torch.tensor([1 / 512, 255 / 256, 1 / 512], dtype=torch.float64), + c_error=torch.tensor( + [-1 / 512, 0, 1 / 512], dtype=torch.float64 + ), +) + +_FE_C_MID = torch.tensor([0.0, 0.5, 0.0], dtype=torch.float64) + + +class Fehlberg2(RKAdaptiveStepsizeODESolver): + order = 2 + tableau = _FEHLBERG2_TABLEAU + mid = _FE_C_MID diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/fixed_adams.py b/third_party/torchdiffeq/torchdiffeq/_impl/fixed_adams.py new file mode 100644 index 0000000..daff802 --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/fixed_adams.py @@ -0,0 +1,228 @@ +import collections +import sys +import torch +import warnings +from .solvers import FixedGridODESolver +from .misc import _compute_error_ratio, _linf_norm +from .misc import Perturb +from .rk_common import rk4_alt_step_func + +_BASHFORTH_COEFFICIENTS = [ + [], # order 0 + [11], + [3, -1], + [23, -16, 5], + [55, -59, 37, -9], + [1901, -2774, 2616, -1274, 251], + [4277, -7923, 9982, -7298, 2877, -475], + [198721, -447288, 705549, -688256, 407139, -134472, 19087], + [434241, -1152169, 2183877, -2664477, 2102243, -1041723, 295767, -36799], + [14097247, -43125206, 95476786, -139855262, 137968480, -91172642, 38833486, -9664106, 1070017], + [30277247, -104995189, 265932680, -454661776, 538363838, -444772162, 252618224, -94307320, 20884811, -2082753], + [ + 2132509567, -8271795124, 23591063805, -46113029016, 63716378958, -63176201472, 44857168434, -22329634920, + 7417904451, -1479574348, 134211265 + ], + [ + 4527766399, -19433810163, 61633227185, -135579356757, 214139355366, -247741639374, 211103573298, -131365867290, + 58189107627, -17410248271, 3158642445, -262747265 + ], + [ + 13064406523627, -61497552797274, 214696591002612, -524924579905150, 932884546055895, -1233589244941764, + 1226443086129408, -915883387152444, 507140369728425, -202322913738370, 55060974662412, -9160551085734, + 703604254357 + ], + [ + 27511554976875, -140970750679621, 537247052515662, -1445313351681906, 2854429571790805, -4246767353305755, + 4825671323488452, -4204551925534524, 2793869602879077, -1393306307155755, 505586141196430, -126174972681906, + 19382853593787, -1382741929621 + ], + [ + 173233498598849, -960122866404112, 3966421670215481, -11643637530577472, 25298910337081429, -41825269932507728, + 53471026659940509, -53246738660646912, 41280216336284259, -24704503655607728, 11205849753515179, + -3728807256577472, 859236476684231, -122594813904112, 8164168737599 + ], + [ + 362555126427073, -2161567671248849, 9622096909515337, -30607373860520569, 72558117072259733, + -131963191940828581, 187463140112902893, -210020588912321949, 186087544263596643, -129930094104237331, + 70724351582843483, -29417910911251819, 9038571752734087, -1934443196892599, 257650275915823, -16088129229375 + ], + [ 
+ 192996103681340479, -1231887339593444974, 5878428128276811750, -20141834622844109630, 51733880057282977010, + -102651404730855807942, 160414858999474733422, -199694296833704562550, 199061418623907202560, + -158848144481581407370, 100878076849144434322, -50353311405771659322, 19338911944324897550, + -5518639984393844930, 1102560345141059610, -137692773163513234, 8092989203533249 + ], + [ + 401972381695456831, -2735437642844079789, 13930159965811142228, -51150187791975812900, 141500575026572531760, + -304188128232928718008, 518600355541383671092, -710171024091234303204, 786600875277595877750, + -706174326992944287370, 512538584122114046748, -298477260353977522892, 137563142659866897224, + -49070094880794267600, 13071639236569712860, -2448689255584545196, 287848942064256339, -15980174332775873 + ], + [ + 333374427829017307697, -2409687649238345289684, 13044139139831833251471, -51099831122607588046344, + 151474888613495715415020, -350702929608291455167896, 647758157491921902292692, -967713746544629658690408, + 1179078743786280451953222, -1176161829956768365219840, 960377035444205950813626, -639182123082298748001432, + 343690461612471516746028, -147118738993288163742312, 48988597853073465932820, -12236035290567356418552, + 2157574942881818312049, -239560589366324764716, 12600467236042756559 + ], + [ + 691668239157222107697, -5292843584961252933125, 30349492858024727686755, -126346544855927856134295, + 399537307669842150996468, -991168450545135070835076, 1971629028083798845750380, -3191065388846318679544380, + 4241614331208149947151790, -4654326468801478894406214, 4222756879776354065593786, -3161821089800186539248210, + 1943018818982002395655620, -970350191086531368649620, 387739787034699092364924, -121059601023985433003532, + 28462032496476316665705, -4740335757093710713245, 498669220956647866875, -24919383499187492303 + ], +] + +_MOULTON_COEFFICIENTS = [ + [], # order 0 + [1], + [1, 1], + [5, 8, -1], + [9, 19, -5, 1], + [251, 646, -264, 106, -19], + [475, 1427, -798, 482, -173, 27], + [19087, 65112, -46461, 37504, -20211, 6312, -863], + [36799, 139849, -121797, 123133, -88547, 41499, -11351, 1375], + [1070017, 4467094, -4604594, 5595358, -5033120, 3146338, -1291214, 312874, -33953], + [2082753, 9449717, -11271304, 16002320, -17283646, 13510082, -7394032, 2687864, -583435, 57281], + [ + 134211265, 656185652, -890175549, 1446205080, -1823311566, 1710774528, -1170597042, 567450984, -184776195, + 36284876, -3250433 + ], + [ + 262747265, 1374799219, -2092490673, 3828828885, -5519460582, 6043521486, -4963166514, 3007739418, -1305971115, + 384709327, -68928781, 5675265 + ], + [ + 703604254357, 3917551216986, -6616420957428, 13465774256510, -21847538039895, 27345870698436, -26204344465152, + 19058185652796, -10344711794985, 4063327863170, -1092096992268, 179842822566, -13695779093 + ], + [ + 1382741929621, 8153167962181, -15141235084110, 33928990133618, -61188680131285, 86180228689563, -94393338653892, + 80101021029180, -52177910882661, 25620259777835, -9181635605134, 2268078814386, -345457086395, 24466579093 + ], + [ + 8164168737599, 50770967534864, -102885148956217, 251724894607936, -499547203754837, 781911618071632, + -963605400824733, 934600833490944, -710312834197347, 418551804601264, -187504936597931, 61759426692544, + -14110480969927, 1998759236336, -132282840127 + ], + [ + 16088129229375, 105145058757073, -230992163723849, 612744541065337, -1326978663058069, 2285168598349733, + -3129453071993581, 3414941728852893, -2966365730265699, 2039345879546643, -1096355235402331, 451403108933483, + 
-137515713789319, 29219384284087, -3867689367599, 240208245823
+    ],
+    [
+        8092989203533249, 55415287221275246, -131240807912923110, 375195469874202430, -880520318434977010,
+        1654462865819232198, -2492570347928318318, 3022404969160106870, -2953729295811279360, 2320851086013919370,
+        -1455690451266780818, 719242466216944698, -273894214307914510, 77597639915764930, -15407325991235610,
+        1913813460537746, -111956703448001
+    ],
+    [
+        15980174332775873, 114329243705491117, -290470969929371220, 890337710266029860, -2250854333681641520,
+        4582441343348851896, -7532171919277411636, 10047287575124288740, -10910555637627652470, 9644799218032932490,
+        -6913858539337636636, 3985516155854664396, -1821304040326216520, 645008976643217360, -170761422500096220,
+        31816981024600492, -3722582669836627, 205804074290625
+    ],
+    [
+        12600467236042756559, 93965550344204933076, -255007751875033918095, 834286388106402145800,
+        -2260420115705863623660, 4956655592790542146968, -8827052559979384209108, 12845814402199484797800,
+        -15345231910046032448070, 15072781455122686545920, -12155867625610599812538, 8008520809622324571288,
+        -4269779992576330506540, 1814584564159445787240, -600505972582990474260, 149186846171741510136,
+        -26182538841925312881, 2895045518506940460, -151711881512390095
+    ],
+    [
+        24919383499187492303, 193280569173472261637, -558160720115629395555, 1941395668950986461335,
+        -5612131802364455926260, 13187185898439270330756, -25293146116627869170796, 39878419226784442421820,
+        -51970649453670274135470, 56154678684618739939910, -50320851025594566473146, 37297227252822858381906,
+        -22726350407538133839300, 11268210124987992327060, -4474886658024166985340, 1389665263296211699212,
+        -325187970422032795497, 53935307402575440285, -5652892248087175675, 281550972898020815
+    ],
+]
+
+_DIVISOR = [
+    None, 11, 2, 12, 24, 720, 1440, 60480, 120960, 3628800, 7257600, 479001600, 958003200, 2615348736000, 5230697472000,
+    31384184832000, 62768369664000, 32011868528640000, 64023737057280000, 51090942171709440000, 102181884343418880000
+]
+
+_BASHFORTH_DIVISOR = [torch.tensor([b / divisor for b in bashforth], dtype=torch.float64)
+                      for bashforth, divisor in zip(_BASHFORTH_COEFFICIENTS, _DIVISOR)]
+_MOULTON_DIVISOR = [torch.tensor([m / divisor for m in moulton], dtype=torch.float64)
+                    for moulton, divisor in zip(_MOULTON_COEFFICIENTS, _DIVISOR)]
+
+_MIN_ORDER = 4
+_MAX_ORDER = 12
+_MAX_ITERS = 4
+
+
+# TODO: replace this with PyTorch operations (a little hard because y is a deque being used as a circular buffer)
+def _dot_product(x, y):
+    return sum(xi * yi for xi, yi in zip(x, y))
+
+
+class AdamsBashforthMoulton(FixedGridODESolver):
+    order = 4
+
+    def __init__(self, func, y0, rtol=1e-3, atol=1e-4, implicit=True, max_iters=_MAX_ITERS, max_order=_MAX_ORDER,
+                 **kwargs):
+        super(AdamsBashforthMoulton, self).__init__(func, y0, rtol=rtol, atol=atol, **kwargs)
+        assert max_order <= _MAX_ORDER, "max_order must be at most {}".format(_MAX_ORDER)
+        if max_order < _MIN_ORDER:
+            warnings.warn("max_order is below {}, so the solver reduces to `rk4`.".format(_MIN_ORDER))
+
+        self.rtol = torch.as_tensor(rtol, dtype=y0.dtype, device=y0.device)
+        self.atol = torch.as_tensor(atol, dtype=y0.dtype, device=y0.device)
+        self.implicit = implicit
+        self.max_iters = max_iters
+        self.max_order = int(max_order)
+        self.prev_f = collections.deque(maxlen=self.max_order - 1)
+        self.prev_t = None
+
+        self.bashforth = [x.to(y0.device) for x in _BASHFORTH_DIVISOR]
+        self.moulton = [x.to(y0.device) for x in _MOULTON_DIVISOR]
+
+    def _update_history(self, t, f):
+        if self.prev_t is None or self.prev_t != t:
+            self.prev_f.appendleft(f)
+            self.prev_t = t
+
+    def _has_converged(self, y0, y1):
+        """Checks that each element is within the error tolerance."""
+        error_ratio = _compute_error_ratio(torch.abs(y0 - y1), self.rtol, self.atol, y0, y1, _linf_norm)
+        return error_ratio < 1
+
+    def _step_func(self, func, t0, dt, t1, y0):
+        f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
+        self._update_history(t0, f0)
+        order = min(len(self.prev_f), self.max_order - 1)
+        if order < _MIN_ORDER - 1:
+            # Compute using RK4.
+            return rk4_alt_step_func(func, t0, dt, t1, y0, f0=self.prev_f[0], perturb=self.perturb), f0
+        else:
+            # Adams-Bashforth predictor.
+            bashforth_coeffs = self.bashforth[order]
+            dy = _dot_product(dt * bashforth_coeffs, self.prev_f).type_as(y0)  # bashforth is float64 so cast back
+
+            # Adams-Moulton corrector.
+            if self.implicit:
+                moulton_coeffs = self.moulton[order + 1]
+                delta = dt * _dot_product(moulton_coeffs[1:], self.prev_f).type_as(y0)  # moulton is float64 so cast back
+                converged = False
+                for _ in range(self.max_iters):
+                    dy_old = dy
+                    f = func(t1, y0 + dy, perturb=Perturb.PREV if self.perturb else Perturb.NONE)
+                    dy = (dt * moulton_coeffs[0] * f).type_as(y0) + delta  # moulton is float64 so cast back
+                    converged = self._has_converged(dy_old, dy)
+                    if converged:
+                        break
+                if not converged:
+                    warnings.warn('Functional iteration did not converge. Solution may be incorrect.')
+                self.prev_f.pop()
+                self._update_history(t0, f)
+            return dy, f0
+
+
+class AdamsBashforth(AdamsBashforthMoulton):
+    def __init__(self, func, y0, **kwargs):
+        super(AdamsBashforth, self).__init__(func, y0, implicit=False, **kwargs)
diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/fixed_grid.py b/third_party/torchdiffeq/torchdiffeq/_impl/fixed_grid.py
new file mode 100644
index 0000000..7578627
--- /dev/null
+++ b/third_party/torchdiffeq/torchdiffeq/_impl/fixed_grid.py
@@ -0,0 +1,29 @@
+from .solvers import FixedGridODESolver
+from .rk_common import rk4_alt_step_func
+from .misc import Perturb
+
+
+class Euler(FixedGridODESolver):
+    order = 1
+
+    def _step_func(self, func, t0, dt, t1, y0):
+        f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
+        return dt * f0, f0
+
+
+class Midpoint(FixedGridODESolver):
+    order = 2
+
+    def _step_func(self, func, t0, dt, t1, y0):
+        half_dt = 0.5 * dt
+        f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
+        y_mid = y0 + f0 * half_dt
+        return dt * func(t0 + half_dt, y_mid), f0
+
+
+class RK4(FixedGridODESolver):
+    order = 4
+
+    def _step_func(self, func, t0, dt, t1, y0):
+        f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
+        return rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0, perturb=self.perturb), f0
diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/interp.py b/third_party/torchdiffeq/torchdiffeq/_impl/interp.py
new file mode 100644
index 0000000..74c6bfe
--- /dev/null
+++ b/third_party/torchdiffeq/torchdiffeq/_impl/interp.py
@@ -0,0 +1,48 @@
+def _interp_fit(y0, y1, y_mid, f0, f1, dt):
+    """Fit coefficients for 4th order polynomial interpolation.
+
+    Args:
+        y0: function value at the start of the interval.
+        y1: function value at the end of the interval.
+        y_mid: function value at the mid-point of the interval.
+        f0: derivative value at the start of the interval.
+        f1: derivative value at the end of the interval.
+        dt: width of the interval.
+
+    Returns:
+        List of coefficients `[e, d, c, b, a]`, ordered from the constant up to the
+        quartic term, for interpolating with the polynomial
+        `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
+        between 0 (start of interval) and 1 (end of interval).
+    """
+    a = 2 * dt * (f1 - f0) - 8 * (y1 + y0) + 16 * y_mid
+    b = dt * (5 * f0 - 3 * f1) + 18 * y0 + 14 * y1 - 32 * y_mid
+    c = dt * (f1 - 4 * f0) - 11 * y0 - 5 * y1 + 16 * y_mid
+    d = dt * f0
+    e = y0
+    return [e, d, c, b, a]
+
+
+def _interp_evaluate(coefficients, t0, t1, t):
+    """Evaluate polynomial interpolation at the given time point.
+
+    Args:
+        coefficients: list of Tensor coefficients as created by `interp_fit`.
+        t0: scalar float64 Tensor giving the start of the interval.
+        t1: scalar float64 Tensor giving the end of the interval.
+        t: scalar float64 Tensor giving the desired interpolation point.
+
+    Returns:
+        Polynomial interpolation of the coefficients at time `t`.
+    """
+
+    assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)
+    x = (t - t0) / (t1 - t0)
+    x = x.to(coefficients[0].dtype)
+
+    total = coefficients[0] + x * coefficients[1]
+    x_power = x
+    for coefficient in coefficients[2:]:
+        x_power = x_power * x
+        total = total + x_power * coefficient
+
+    return total
diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/misc.py b/third_party/torchdiffeq/torchdiffeq/_impl/misc.py
new file mode 100644
index 0000000..e4fa652
--- /dev/null
+++ b/third_party/torchdiffeq/torchdiffeq/_impl/misc.py
@@ -0,0 +1,353 @@
+from enum import Enum
+import math
+import numpy as np
+import torch
+import warnings
+from .event_handling import combine_event_functions
+
+
+def _handle_unused_kwargs(solver, unused_kwargs):
+    if len(unused_kwargs) > 0:
+        warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs))
+
+
+def _linf_norm(tensor):
+    return tensor.max()
+
+
+def _rms_norm(tensor):
+    return tensor.pow(2).mean().sqrt()
+
+
+def _zero_norm(tensor):
+    return 0.
+
+
+def _mixed_norm(tensor_tuple):
+    if len(tensor_tuple) == 0:
+        return 0.
+    return max([_rms_norm(tensor) for tensor in tensor_tuple])
+
+
+def _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0=None):
+    """Empirically select a good initial step.
+
+    The algorithm is described in [1]_.
+
+    References
+    ----------
+    .. [1] E. Hairer, S. P. Norsett, G. Wanner, "Solving Ordinary Differential
+        Equations I: Nonstiff Problems", Sec. II.4, 2nd edition.
+    """
+
+    dtype = y0.dtype
+    device = y0.device
+    t_dtype = t0.dtype
+    t0 = t0.to(dtype)
+
+    if f0 is None:
+        f0 = func(t0, y0)
+
+    scale = atol + torch.abs(y0) * rtol
+
+    d0 = norm(y0 / scale)
+    d1 = norm(f0 / scale)
+
+    if d0 < 1e-5 or d1 < 1e-5:
+        h0 = torch.tensor(1e-6, dtype=dtype, device=device)
+    else:
+        h0 = 0.01 * d0 / d1
+
+    y1 = y0 + h0 * f0
+    f1 = func(t0 + h0, y1)
+
+    d2 = norm((f1 - f0) / scale) / h0
+
+    if d1 <= 1e-15 and d2 <= 1e-15:
+        h1 = torch.max(torch.tensor(1e-6, dtype=dtype, device=device), h0 * 1e-3)
+    else:
+        h1 = (0.01 / max(d1, d2)) ** (1.
/ float(order + 1)) + + return torch.min(100 * h0, h1).to(t_dtype) + + +def _compute_error_ratio(error_estimate, rtol, atol, y0, y1, norm): + error_tol = atol + rtol * torch.max(y0.abs(), y1.abs()) + return norm(error_estimate / error_tol) + + +@torch.no_grad() +def _optimal_step_size(last_step, error_ratio, safety, ifactor, dfactor, order): + """Calculate the optimal size for the next step.""" + if error_ratio == 0: + return last_step * ifactor + if error_ratio < 1: + dfactor = torch.ones((), dtype=last_step.dtype, device=last_step.device) + error_ratio = error_ratio.type_as(last_step) + exponent = torch.tensor(order, dtype=last_step.dtype, device=last_step.device).reciprocal() + factor = torch.min(ifactor, torch.max(safety / error_ratio ** exponent, dfactor)) + return last_step * factor + + +def _decreasing(t): + return (t[1:] < t[:-1]).all() + + +def _assert_one_dimensional(name, t): + assert t.ndimension() == 1, "{} must be one dimensional".format(name) + + +def _assert_increasing(name, t): + assert (t[1:] > t[:-1]).all(), '{} must be strictly increasing or decreasing'.format(name) + + +def _assert_floating(name, t): + if not torch.is_floating_point(t): + raise TypeError('`{}` must be a floating point Tensor but is a {}'.format(name, t.type())) + + +def _tuple_tol(name, tol, shapes): + try: + iter(tol) + except TypeError: + return tol + tol = tuple(tol) + assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0; get {}, {}".format(name, len(tol), len(shapes)) + tol = [torch.as_tensor(tol_).expand(shape.numel()) for tol_, shape in zip(tol, shapes)] + return torch.cat(tol) + + +def _flat_to_shape(tensor, length, shapes): + tensor_list = [] + total = 0 + for shape in shapes: + next_total = total + shape.numel() + # It's important that this be view((...)), not view(...). Else when length=(), shape=() it fails. + tensor_list.append(tensor[..., total:next_total].view((*length, *shape))) + total = next_total + return tuple(tensor_list) + + +class _TupleFunc(torch.nn.Module): + def __init__(self, base_func, shapes): + super(_TupleFunc, self).__init__() + self.base_func = base_func + self.shapes = shapes + + def forward(self, t, y): + f = self.base_func(t, _flat_to_shape(y, (), self.shapes)) + return torch.cat([f_.reshape(-1) for f_ in f]) + + +class _TupleInputOnlyFunc(torch.nn.Module): + def __init__(self, base_func, shapes): + super(_TupleInputOnlyFunc, self).__init__() + self.base_func = base_func + self.shapes = shapes + + def forward(self, t, y): + return self.base_func(t, _flat_to_shape(y, (), self.shapes)) + + +class _ReverseFunc(torch.nn.Module): + def __init__(self, base_func, mul=1.0): + super(_ReverseFunc, self).__init__() + self.base_func = base_func + self.mul = mul + + def forward(self, t, y): + return self.mul * self.base_func(-t, y) + + +class Perturb(Enum): + NONE = 0 + PREV = 1 + NEXT = 2 + + +class _PerturbFunc(torch.nn.Module): + + def __init__(self, base_func): + super(_PerturbFunc, self).__init__() + self.base_func = base_func + + def forward(self, t, y, *, perturb=Perturb.NONE): + assert isinstance(perturb, Perturb), "perturb argument must be of type Perturb enum" + # This dtype change here might be buggy. + # The exact time value should be determined inside the solver, + # but this can slightly change it due to numerical differences during casting. + t = t.to(y.dtype) + if perturb is Perturb.NEXT: + # Replace with next smallest representable value. 
+ t = _nextafter(t, t + 1) + elif perturb is Perturb.PREV: + # Replace with prev largest representable value. + t = _nextafter(t, t - 1) + else: + # Do nothing. + pass + return self.base_func(t, y) + + +def _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS): + + if event_fn is not None: + if len(t) != 2: + raise ValueError(f"We require len(t) == 2 when in event handling mode, but got len(t)={len(t)}.") + + # Combine event functions if the output is multivariate. + event_fn = combine_event_functions(event_fn, t[0], y0) + + # Normalise to tensor (non-tupled) input + shapes = None + is_tuple = not isinstance(y0, torch.Tensor) + if is_tuple: + assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple' + shapes = [y0_.shape for y0_ in y0] + rtol = _tuple_tol('rtol', rtol, shapes) + atol = _tuple_tol('atol', atol, shapes) + y0 = torch.cat([y0_.reshape(-1) for y0_ in y0]) + func = _TupleFunc(func, shapes) + if event_fn is not None: + event_fn = _TupleInputOnlyFunc(event_fn, shapes) + _assert_floating('y0', y0) + + # Normalise method and options + if options is None: + options = {} + else: + options = options.copy() + if method is None: + method = 'dopri5' + if method not in SOLVERS: + raise ValueError('Invalid method "{}". Must be one of {}'.format(method, + '{"' + '", "'.join(SOLVERS.keys()) + '"}.')) + + if is_tuple: + # We accept tupled input. This is an abstraction that is hidden from the rest of odeint (exception when + # returning values), so here we need to maintain the abstraction by wrapping norm functions. + + if 'norm' in options: + # If the user passed a norm then get that... + norm = options['norm'] + else: + # ...otherwise we default to a mixed Linf/L2 norm over tupled input. + norm = _mixed_norm + + # In either case, norm(...) is assumed to take a tuple of tensors as input. (As that's what the state looks + # like from the point of view of the user.) + # So here we take the tensor that the machinery of odeint has given us, and turn it in the tuple that the + # norm function is expecting. + def _norm(tensor): + y = _flat_to_shape(tensor, (), shapes) + return norm(y) + options['norm'] = _norm + + else: + if 'norm' in options: + # No need to change the norm function. + pass + else: + # Else just use the default norm. + # Technically we don't need to set that here (RKAdaptiveStepsizeODESolver has it as a default), but it + # makes it easier to reason about, in the adjoint norm logic, if we know that options['norm'] is + # definitely set to something. + options['norm'] = _rms_norm + + # Normalise time + _check_timelike('t', t, True) + t_is_reversed = False + if len(t) > 1 and t[0] > t[1]: + t_is_reversed = True + + if t_is_reversed: + # Change the integration times to ascending order. + # We do this by negating the time values and all associated arguments. + t = -t + + # Ensure time values are un-negated when calling functions. + func = _ReverseFunc(func, mul=-1.0) + if event_fn is not None: + event_fn = _ReverseFunc(event_fn) + + # For fixed step solvers. + try: + _grid_constructor = options['grid_constructor'] + except KeyError: + pass + else: + options['grid_constructor'] = lambda func, y0, t: -_grid_constructor(func, y0, -t) + + # For RK solvers. 
+        _flip_option(options, 'step_t')
+        _flip_option(options, 'jump_t')
+
+    # Can only do this after having normalised time
+    _assert_increasing('t', t)
+
+    # Tolerance checking
+    if torch.is_tensor(rtol):
+        assert not rtol.requires_grad, "rtol cannot require gradient"
+    if torch.is_tensor(atol):
+        assert not atol.requires_grad, "atol cannot require gradient"
+
+    # Backward compatibility: Allow t and y0 to be on different devices
+    if t.device != y0.device:
+        warnings.warn("t is not on the same device as y0. Coercing to y0.device.")
+        t = t.to(y0.device)
+    # ~Backward compatibility
+
+    # Add perturb argument to func.
+    func = _PerturbFunc(func)
+
+    return shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed
+
+
+class _StitchGradient(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x1, out):
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        return grad_out, None
+
+
+def _nextafter(x1, x2):
+    with torch.no_grad():
+        if hasattr(torch, "nextafter"):
+            out = torch.nextafter(x1, x2)
+        else:
+            out = np_nextafter(x1, x2)
+    return _StitchGradient.apply(x1, out)
+
+
+def np_nextafter(x1, x2):
+    warnings.warn("torch.nextafter is only available in PyTorch 1.7 or newer. "
+                  "Falling back to numpy.nextafter. Upgrade PyTorch to remove this warning.")
+    x1_np = x1.detach().cpu().numpy()
+    x2_np = x2.detach().cpu().numpy()
+    out = torch.tensor(np.nextafter(x1_np, x2_np)).to(x1)
+    return out
+
+
+def _check_timelike(name, timelike, can_grad):
+    assert isinstance(timelike, torch.Tensor), '{} must be a torch.Tensor'.format(name)
+    _assert_floating(name, timelike)
+    assert timelike.ndimension() == 1, "{} must be one dimensional".format(name)
+    if not can_grad:
+        assert not timelike.requires_grad, "{} cannot require gradient".format(name)
+    diff = timelike[1:] > timelike[:-1]
+    assert diff.all() or (~diff).all(), '{} must be strictly increasing or decreasing'.format(name)
+
+
+def _flip_option(options, option_name):
+    try:
+        option_value = options[option_name]
+    except KeyError:
+        pass
+    else:
+        if isinstance(option_value, torch.Tensor):
+            options[option_name] = -option_value
+        # else: an error will be raised when the option is attempted to be used in Solver.__init__, but we defer raising
+        # the error until then to keep things tidy.
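For orientation, here is a minimal usage sketch of the public `odeint` entry point defined in the next file. The dynamics, tolerances, and solver choice are illustrative only, and the import assumes the vendored package resolves as `torchdiffeq`:

```
import torch
from torchdiffeq import odeint

# Exponential decay dy/dt = -y with y(0) = 1; exact solution y(t) = exp(-t).
func = lambda t, y: -y
y0 = torch.tensor([1.0])
t = torch.linspace(0.0, 1.0, 5)

# Adaptive Dormand-Prince (dopri5) with per-element error bounds rtol/atol.
ys = odeint(func, y0, t, rtol=1e-7, atol=1e-9, method='dopri5')
# ys has shape (5, 1); ys[-1] should be close to exp(-1) ~ 0.3679.
```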
diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/odeint.py b/third_party/torchdiffeq/torchdiffeq/_impl/odeint.py new file mode 100644 index 0000000..a174219 --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/odeint.py @@ -0,0 +1,164 @@ +import torch +from torch.autograd.functional import vjp +from .dopri5 import Dopri5Solver +from .bosh3 import Bosh3Solver +from .adaptive_heun import AdaptiveHeunSolver +from .fehlberg2 import Fehlberg2 +from .fixed_grid import Euler, Midpoint, RK4 +from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton +from .dopri8 import Dopri8Solver +from .scipy_wrapper import ScipyWrapperODESolver +from .misc import _check_inputs, _flat_to_shape + +SOLVERS = { + 'dopri8': Dopri8Solver, + 'dopri5': Dopri5Solver, + 'bosh3': Bosh3Solver, + 'fehlberg2': Fehlberg2, + 'adaptive_heun': AdaptiveHeunSolver, + 'euler': Euler, + 'midpoint': Midpoint, + 'rk4': RK4, + 'explicit_adams': AdamsBashforth, + 'implicit_adams': AdamsBashforthMoulton, + # Backward compatibility: use the same name as before + 'fixed_adams': AdamsBashforthMoulton, + # ~Backwards compatibility + 'scipy_solver': ScipyWrapperODESolver, +} + + +def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None): + """Integrate a system of ordinary differential equations. + + Solves the initial value problem for a non-stiff system of first order ODEs: + ``` + dy/dt = func(t, y), y(t[0]) = y0 + ``` + where y is a Tensor or tuple of Tensors of any shape. + + Output dtypes and numerical precision are based on the dtypes of the inputs `y0`. + + Args: + func: Function that maps a scalar Tensor `t` and a Tensor holding the state `y` + into a Tensor of state derivatives with respect to time. Optionally, `y` + can also be a tuple of Tensors. + y0: N-D Tensor giving starting value of `y` at time point `t[0]`. Optionally, `y0` + can also be a tuple of Tensors. + t: 1-D Tensor holding a sequence of time points for which to solve for + `y`, in either increasing or decreasing order. The first element of + this sequence is taken to be the initial time point. + rtol: optional float64 Tensor specifying an upper bound on relative error, + per element of `y`. + atol: optional float64 Tensor specifying an upper bound on absolute error, + per element of `y`. + method: optional string indicating the integration method to use. + options: optional dict of configuring options for the indicated integration + method. Can only be provided if a `method` is explicitly set. + event_fn: Function that maps the state `y` to a Tensor. The solve terminates when + event_fn evaluates to zero. If this is not None, all but the first elements of + `t` are ignored. + + Returns: + y: Tensor, where the first dimension corresponds to different + time points. Contains the solved value of y for each desired time point in + `t`, with the initial value `y0` being the first element along the first + dimension. + + Raises: + ValueError: if an invalid `method` is provided. 
+ """ + + shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS) + + solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options) + + if event_fn is None: + solution = solver.integrate(t) + else: + event_t, solution = solver.integrate_until_event(t[0], event_fn) + event_t = event_t.to(t) + if t_is_reversed: + event_t = -event_t + + if shapes is not None: + solution = _flat_to_shape(solution, (len(t),), shapes) + + if event_fn is None: + return solution + else: + return event_t, solution + + +def odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs): + """Automatically links up the gradient from the event time.""" + + if reverse_time: + t = torch.cat([t0.reshape(-1), t0.reshape(-1).detach() - 1.0]) + else: + t = torch.cat([t0.reshape(-1), t0.reshape(-1).detach() + 1.0]) + + event_t, solution = odeint_interface(func, y0, t, event_fn=event_fn, **kwargs) + + # Dummy values for rtol, atol, method, and options. + shapes, _func, _, t, _, _, _, _, event_fn, _ = _check_inputs(func, y0, t, 0.0, 0.0, None, None, event_fn, SOLVERS) + + if shapes is not None: + state_t = torch.cat([s[-1].reshape(-1) for s in solution]) + else: + state_t = solution[-1] + + # Event_fn takes in negated time value if reverse_time is True. + if reverse_time: + event_t = -event_t + + event_t, state_t = ImplicitFnGradientRerouting.apply(_func, event_fn, event_t, state_t) + + # Return the user expected time value. + if reverse_time: + event_t = -event_t + + if shapes is not None: + state_t = _flat_to_shape(state_t, (), shapes) + solution = tuple(torch.cat([s[:-1], s_t[None]], dim=0) for s, s_t in zip(solution, state_t)) + else: + solution = torch.cat([solution[:-1], state_t[None]], dim=0) + + return event_t, solution + + +class ImplicitFnGradientRerouting(torch.autograd.Function): + + @staticmethod + def forward(ctx, func, event_fn, event_t, state_t): + """ event_t is the solution to event_fn """ + ctx.func = func + ctx.event_fn = event_fn + ctx.save_for_backward(event_t, state_t) + return event_t.detach(), state_t.detach() + + @staticmethod + def backward(ctx, grad_t, grad_state): + func = ctx.func + event_fn = ctx.event_fn + event_t, state_t = ctx.saved_tensors + + event_t = event_t.detach().clone().requires_grad_(True) + state_t = state_t.detach().clone().requires_grad_(True) + + f_val = func(event_t, state_t) + + with torch.enable_grad(): + c, (par_dt, dstate) = vjp(event_fn, (event_t, state_t)) + + # Total derivative of event_fn wrt t evaluated at event_t. + dcdt = par_dt + torch.sum(dstate * f_val) + + # Add the gradient from final state to final time value as if a regular odeint was called. 
+ grad_t = grad_t + torch.sum(grad_state * f_val) + + dstate = dstate * (-grad_t / (dcdt + 1e-12)).reshape_as(c) + + grad_state = grad_state + dstate + + return None, None, None, grad_state diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/rk_common.py b/third_party/torchdiffeq/torchdiffeq/_impl/rk_common.py new file mode 100644 index 0000000..4d4f60e --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/rk_common.py @@ -0,0 +1,307 @@ +import bisect +import collections +import torch +from .event_handling import find_event +from .interp import _interp_evaluate, _interp_fit +from .misc import (_compute_error_ratio, + _select_initial_step, + _optimal_step_size) +from .misc import Perturb +from .solvers import AdaptiveStepsizeEventODESolver + + +_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha, beta, c_sol, c_error') + + +_RungeKuttaState = collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff') +# Saved state of the Runge Kutta solver. +# +# Attributes: +# y1: Tensor giving the function value at the end of the last time step. +# f1: Tensor giving derivative at the end of the last time step. +# t0: scalar float64 Tensor giving start of the last time step. +# t1: scalar float64 Tensor giving end of the last time step. +# dt: scalar float64 Tensor giving the size for the next time step. +# interp_coeff: list of Tensors giving coefficients for polynomial +# interpolation between `t0` and `t1`. + + +class _UncheckedAssign(torch.autograd.Function): + @staticmethod + def forward(ctx, scratch, value, index): + ctx.index = index + scratch.data[index] = value # sneak past the version checker + return scratch + + @staticmethod + def backward(ctx, grad_scratch): + return grad_scratch, grad_scratch[ctx.index], None + + +def _runge_kutta_step(func, y0, f0, t0, dt, t1, tableau): + """Take an arbitrary Runge-Kutta step and estimate error. + Args: + func: Function to evaluate like `func(t, y)` to compute the time derivative of `y`. + y0: Tensor initial value for the state. + f0: Tensor initial value for the derivative, computed from `func(t0, y0)`. + t0: float64 scalar Tensor giving the initial time. + dt: float64 scalar Tensor giving the size of the desired time step. + t1: float64 scalar Tensor giving the end time; equal to t0 + dt. This is used (rather than t0 + dt) to ensure + floating point accuracy when needed. + tableau: _ButcherTableau describing how to take the Runge-Kutta step. + Returns: + Tuple `(y1, f1, y1_error, k)` giving the estimated function value after + the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`, + estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for + calculating these terms. + """ + + t0 = t0.to(y0.dtype) + dt = dt.to(y0.dtype) + t1 = t1.to(y0.dtype) + + # We use an unchecked assign to put data into k without incrementing its _version counter, so that the backward + # doesn't throw an (overzealous) error about in-place correctness. We know that it's actually correct. + k = torch.empty(*f0.shape, len(tableau.alpha) + 1, dtype=y0.dtype, device=y0.device) + k = _UncheckedAssign.apply(k, f0, (..., 0)) + for i, (alpha_i, beta_i) in enumerate(zip(tableau.alpha, tableau.beta)): + if alpha_i == 1.: + # Always step to perturbing just before the end time, in case of discontinuities. 
+ ti = t1 + perturb = Perturb.PREV + else: + ti = t0 + alpha_i * dt + perturb = Perturb.NONE + yi = y0 + k[..., :i + 1].matmul(beta_i * dt).view_as(f0) + f = func(ti, yi, perturb=perturb) + k = _UncheckedAssign.apply(k, f, (..., i + 1)) + + if not (tableau.c_sol[-1] == 0 and (tableau.c_sol[:-1] == tableau.beta[-1]).all()): + # This property (true for Dormand-Prince) lets us save a few FLOPs. + yi = y0 + k.matmul(dt * tableau.c_sol).view_as(f0) + + y1 = yi + f1 = k[..., -1] + y1_error = k.matmul(dt * tableau.c_error) + return y1, f1, y1_error, k + + +# Precompute divisions +_one_third = 1 / 3 +_two_thirds = 2 / 3 +_one_sixth = 1 / 6 + + +def rk4_step_func(func, t0, dt, t1, y0, f0=None, perturb=False): + k1 = f0 + if k1 is None: + k1 = func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE) + half_dt = dt * 0.5 + k2 = func(t0 + half_dt, y0 + half_dt * k1) + k3 = func(t0 + half_dt, y0 + half_dt * k2) + k4 = func(t1, y0 + dt * k3, perturb=Perturb.PREV if perturb else Perturb.NONE) + return (k1 + 2 * (k2 + k3) + k4) * dt * _one_sixth + + +def rk4_alt_step_func(func, t0, dt, t1, y0, f0=None, perturb=False): + """Smaller error with slightly more compute.""" + k1 = f0 + if k1 is None: + k1 = func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE) + k2 = func(t0 + dt * _one_third, y0 + dt * k1 * _one_third) + k3 = func(t0 + dt * _two_thirds, y0 + dt * (k2 - k1 * _one_third)) + k4 = func(t1, y0 + dt * (k1 - k2 + k3), perturb=Perturb.PREV if perturb else Perturb.NONE) + return (k1 + 3 * (k2 + k3) + k4) * dt * 0.125 + + +class RKAdaptiveStepsizeODESolver(AdaptiveStepsizeEventODESolver): + order: int + tableau: _ButcherTableau + mid: torch.Tensor + + def __init__(self, func, y0, rtol, atol, + first_step=None, + step_t=None, + jump_t=None, + safety=0.9, + ifactor=10.0, + dfactor=0.2, + max_num_steps=2 ** 31 - 1, + dtype=torch.float64, + **kwargs): + super(RKAdaptiveStepsizeODESolver, self).__init__(dtype=dtype, y0=y0, **kwargs) + + # We use mixed precision. y has its original dtype (probably float32), whilst all 'time'-like objects use + # `dtype` (defaulting to float64). 
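+        # Promote so that time-like tensors are at least as precise as the state tensor.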
+ dtype = torch.promote_types(dtype, y0.dtype) + device = y0.device + + self.func = func + self.rtol = torch.as_tensor(rtol, dtype=dtype, device=device) + self.atol = torch.as_tensor(atol, dtype=dtype, device=device) + self.first_step = None if first_step is None else torch.as_tensor(first_step, dtype=dtype, device=device) + self.safety = torch.as_tensor(safety, dtype=dtype, device=device) + self.ifactor = torch.as_tensor(ifactor, dtype=dtype, device=device) + self.dfactor = torch.as_tensor(dfactor, dtype=dtype, device=device) + self.max_num_steps = torch.as_tensor(max_num_steps, dtype=torch.int32, device=device) + self.dtype = dtype + + self.step_t = None if step_t is None else torch.as_tensor(step_t, dtype=dtype, device=device) + self.jump_t = None if jump_t is None else torch.as_tensor(jump_t, dtype=dtype, device=device) + + # Copy from class to instance to set device + self.tableau = _ButcherTableau(alpha=self.tableau.alpha.to(device=device, dtype=y0.dtype), + beta=[b.to(device=device, dtype=y0.dtype) for b in self.tableau.beta], + c_sol=self.tableau.c_sol.to(device=device, dtype=y0.dtype), + c_error=self.tableau.c_error.to(device=device, dtype=y0.dtype)) + self.mid = self.mid.to(device=device, dtype=y0.dtype) + + def _before_integrate(self, t): + t0 = t[0] + f0 = self.func(t[0], self.y0) + if self.first_step is None: + first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol, + self.norm, f0=f0) + else: + first_step = self.first_step + self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, [self.y0] * 5) + + # Handle step_t and jump_t arguments. + if self.step_t is None: + step_t = torch.tensor([], dtype=self.dtype, device=self.y0.device) + else: + step_t = _sort_tvals(self.step_t, t0) + step_t = step_t.to(self.dtype) + if self.jump_t is None: + jump_t = torch.tensor([], dtype=self.dtype, device=self.y0.device) + else: + jump_t = _sort_tvals(self.jump_t, t0) + jump_t = jump_t.to(self.dtype) + counts = torch.cat([step_t, jump_t]).unique(return_counts=True)[1] + if (counts > 1).any(): + raise ValueError("`step_t` and `jump_t` must not have any repeated elements between them.") + + self.step_t = step_t + self.jump_t = jump_t + self.next_step_index = min(bisect.bisect(self.step_t.tolist(), t[0]), len(self.step_t) - 1) + self.next_jump_index = min(bisect.bisect(self.jump_t.tolist(), t[0]), len(self.jump_t) - 1) + + def _advance(self, next_t): + """Interpolate through the next time point, integrating as necessary.""" + n_steps = 0 + while next_t > self.rk_state.t1: + assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) + self.rk_state = self._adaptive_step(self.rk_state) + n_steps += 1 + return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) + + def _advance_until_event(self, event_fn): + """Returns t, state(t) such that event_fn(t, state(t)) == 0.""" + if event_fn(self.rk_state.t1, self.rk_state.y1) == 0: + return (self.rk_state.t1, self.rk_state.y1) + + n_steps = 0 + sign0 = torch.sign(event_fn(self.rk_state.t1, self.rk_state.y1)) + while sign0 == torch.sign(event_fn(self.rk_state.t1, self.rk_state.y1)): + assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) + self.rk_state = self._adaptive_step(self.rk_state) + n_steps += 1 + interp_fn = lambda t: _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, t) + return find_event(interp_fn, sign0, self.rk_state.t0, 
self.rk_state.t1, event_fn, self.atol) + + def _adaptive_step(self, rk_state): + """Take an adaptive Runge-Kutta step to integrate the ODE.""" + y0, f0, _, t0, dt, interp_coeff = rk_state + t1 = t0 + dt + # dtypes: self.y0.dtype (probably float32); self.dtype (probably float64) + # used for state and timelike objects respectively. + # Then: + # y0.dtype == self.y0.dtype + # f0.dtype == self.y0.dtype + # t0.dtype == self.dtype + # dt.dtype == self.dtype + # for coeff in interp_coeff: coeff.dtype == self.y0.dtype + + ######################################################## + # Assertions # + ######################################################## + assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) + assert torch.isfinite(y0).all(), 'non-finite values in state `y`: {}'.format(y0) + + ######################################################## + # Make step, respecting prescribed grid points # + ######################################################## + + on_step_t = False + if len(self.step_t): + next_step_t = self.step_t[self.next_step_index] + on_step_t = t0 < next_step_t < t0 + dt + if on_step_t: + t1 = next_step_t + dt = t1 - t0 + + on_jump_t = False + if len(self.jump_t): + next_jump_t = self.jump_t[self.next_jump_index] + on_jump_t = t0 < next_jump_t < t0 + dt + if on_jump_t: + on_step_t = False + t1 = next_jump_t + dt = t1 - t0 + + # Must be arranged as doing all the step_t handling, then all the jump_t handling, in case we + # trigger both. (i.e. interleaving them would be wrong.) + + y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, t1, tableau=self.tableau) + # dtypes: + # y1.dtype == self.y0.dtype + # f1.dtype == self.y0.dtype + # y1_error.dtype == self.dtype + # k.dtype == self.y0.dtype + + ######################################################## + # Error Ratio # + ######################################################## + error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm) + accept_step = error_ratio <= 1 + # dtypes: + # error_ratio.dtype == self.dtype + + ######################################################## + # Update RK State # + ######################################################## + if accept_step: + t_next = t1 + y_next = y1 + interp_coeff = self._interp_fit(y0, y_next, k, dt) + if on_step_t: + if self.next_step_index != len(self.step_t) - 1: + self.next_step_index += 1 + if on_jump_t: + if self.next_jump_index != len(self.jump_t) - 1: + self.next_jump_index += 1 + # We've just passed a discontinuity in f; we should update f to match the side of the discontinuity + # we're now on. + f1 = self.func(t_next, y_next, perturb=Perturb.NEXT) + f_next = f1 + else: + t_next = t0 + y_next = y0 + f_next = f0 + dt_next = _optimal_step_size(dt, error_ratio, self.safety, self.ifactor, self.dfactor, self.order) + rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) + return rk_state + + def _interp_fit(self, y0, y1, k, dt): + """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" + dt = dt.type_as(y0) + y_mid = y0 + k.matmul(dt * self.mid).view_as(y0) + f0 = k[..., 0] + f1 = k[..., -1] + return _interp_fit(y0, y1, y_mid, f0, f1, dt) + + +def _sort_tvals(tvals, t0): + # TODO: add warning if tvals come before t0? 
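+    # Keep only the values at or after t0, returned in ascending order.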
+ tvals = tvals[tvals >= t0] + return torch.sort(tvals).values diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/scipy_wrapper.py b/third_party/torchdiffeq/torchdiffeq/_impl/scipy_wrapper.py new file mode 100644 index 0000000..a28c845 --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/scipy_wrapper.py @@ -0,0 +1,53 @@ +import abc +import torch +from scipy.integrate import solve_ivp +from .misc import _handle_unused_kwargs +import numpy as np + +class ScipyWrapperODESolver(metaclass=abc.ABCMeta): + + def __init__(self, func, y0, rtol, atol, solver="LSODA", max_step=np.inf, **unused_kwargs): + unused_kwargs.pop('norm', None) + unused_kwargs.pop('grid_points', None) + unused_kwargs.pop('eps', None) + _handle_unused_kwargs(self, unused_kwargs) + del unused_kwargs + self.max_step = max_step + self.dtype = y0.dtype + self.device = y0.device + self.shape = y0.shape + self.y0 = y0.detach().cpu().numpy().reshape(-1) + self.rtol = rtol + self.atol = atol + self.solver = solver + self.func = convert_func_to_numpy(func, self.shape, self.device, self.dtype) + + def integrate(self, t): + if t.numel() == 1: + return torch.tensor(self.y0)[None].to(self.device, self.dtype) + t = t.detach().cpu().numpy() + sol = solve_ivp( + self.func, + t_span=[t.min(), t.max()], + y0=self.y0, + t_eval=t, + method=self.solver, + rtol=self.rtol, + atol=self.atol, + max_step=self.max_step + ) + sol = torch.tensor(sol.y).T.to(self.device, self.dtype) + sol = sol.reshape(-1, *self.shape) + return sol + + +def convert_func_to_numpy(func, shape, device, dtype): + + def np_func(t, y): + t = torch.tensor(t).to(device, dtype) + y = torch.reshape(torch.tensor(y).to(device, dtype), shape) + with torch.no_grad(): + f = func(t, y) + return f.detach().cpu().numpy().reshape(-1) + + return np_func diff --git a/third_party/torchdiffeq/torchdiffeq/_impl/solvers.py b/third_party/torchdiffeq/torchdiffeq/_impl/solvers.py new file mode 100644 index 0000000..6915f2b --- /dev/null +++ b/third_party/torchdiffeq/torchdiffeq/_impl/solvers.py @@ -0,0 +1,172 @@ +import abc +import torch +from .event_handling import find_event +from .misc import _handle_unused_kwargs + + +class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta): + def __init__(self, dtype, y0, norm, **unused_kwargs): + _handle_unused_kwargs(self, unused_kwargs) + del unused_kwargs + + self.y0 = y0 + self.dtype = dtype + + self.norm = norm + + def _before_integrate(self, t): + pass + + @abc.abstractmethod + def _advance(self, next_t): + raise NotImplementedError + + def integrate(self, t): + solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) + solution[0] = self.y0 + t = t.to(self.dtype) + self._before_integrate(t) + for i in range(1, len(t)): + solution[i] = self._advance(t[i]) + return solution + + +class AdaptiveStepsizeEventODESolver(AdaptiveStepsizeODESolver, metaclass=abc.ABCMeta): + + @abc.abstractmethod + def _advance_until_event(self, event_fn): + raise NotImplementedError + + def integrate_until_event(self, t0, event_fn): + t0 = t0.to(self.y0.device, self.dtype) + self._before_integrate(t0.reshape(-1)) + event_time, y1 = self._advance_until_event(event_fn) + solution = torch.stack([self.y0, y1], dim=0) + return event_time, solution + + +class FixedGridODESolver(metaclass=abc.ABCMeta): + order: int + + def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, **unused_kwargs): + self.atol = unused_kwargs.pop('atol') + unused_kwargs.pop('rtol', None) + unused_kwargs.pop('norm', None) 
+ _handle_unused_kwargs(self, unused_kwargs) + del unused_kwargs + + self.func = func + self.y0 = y0 + self.dtype = y0.dtype + self.device = y0.device + self.step_size = step_size + self.interp = interp + self.perturb = perturb + + if step_size is None: + if grid_constructor is None: + self.grid_constructor = lambda f, y0, t: t + else: + self.grid_constructor = grid_constructor + else: + if grid_constructor is None: + self.grid_constructor = self._grid_constructor_from_step_size(step_size) + else: + raise ValueError("step_size and grid_constructor are mutually exclusive arguments.") + + @staticmethod + def _grid_constructor_from_step_size(step_size): + def _grid_constructor(func, y0, t): + start_time = t[0] + end_time = t[-1] + + niters = torch.ceil((end_time - start_time) / step_size + 1).item() + t_infer = torch.arange(0, niters, dtype=t.dtype, device=t.device) * step_size + start_time + t_infer[-1] = t[-1] + + return t_infer + return _grid_constructor + + @abc.abstractmethod + def _step_func(self, func, t0, dt, t1, y0): + pass + + def integrate(self, t): + time_grid = self.grid_constructor(self.func, self.y0, t) + assert time_grid[0] == t[0] and time_grid[-1] == t[-1] + + solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) + solution[0] = self.y0 + + j = 1 + y0 = self.y0 + for t0, t1 in zip(time_grid[:-1], time_grid[1:]): + dt = t1 - t0 + dy, f0 = self._step_func(self.func, t0, dt, t1, y0) + y1 = y0 + dy + + while j < len(t) and t1 >= t[j]: + if self.interp == "linear": + solution[j] = self._linear_interp(t0, t1, y0, y1, t[j]) + elif self.interp == "cubic": + f1 = self.func(t1, y1) + solution[j] = self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t[j]) + else: + raise ValueError(f"Unknown interpolation method {self.interp}") + j += 1 + y0 = y1 + + return solution + + def integrate_until_event(self, t0, event_fn): + assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options." 
+ + t0 = t0.type_as(self.y0) + y0 = self.y0 + dt = self.step_size + + sign0 = torch.sign(event_fn(t0, y0)) + max_itrs = 20000 + itr = 0 + while True: + itr += 1 + t1 = t0 + dt + dy, f0 = self._step_func(self.func, t0, dt, t1, y0) + y1 = y0 + dy + + sign1 = torch.sign(event_fn(t1, y1)) + + if sign0 != sign1: + if self.interp == "linear": + interp_fn = lambda t: self._linear_interp(t0, t1, y0, y1, t) + elif self.interp == "cubic": + f1 = self.func(t1, y1) + interp_fn = lambda t: self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t) + else: + raise ValueError(f"Unknown interpolation method {self.interp}") + event_time, y1 = find_event(interp_fn, sign0, t0, t1, event_fn, float(self.atol)) + break + else: + t0, y0 = t1, y1 + + if itr >= max_itrs: + raise RuntimeError(f"Reached maximum number of iterations {max_itrs}.") + solution = torch.stack([self.y0, y1], dim=0) + return event_time, solution + + def _cubic_hermite_interp(self, t0, y0, f0, t1, y1, f1, t): + h = (t - t0) / (t1 - t0) + h00 = (1 + 2 * h) * (1 - h) * (1 - h) + h10 = h * (1 - h) * (1 - h) + h01 = h * h * (3 - 2 * h) + h11 = h * h * (h - 1) + dt = (t1 - t0) + return h00 * y0 + h10 * dt * f0 + h01 * y1 + h11 * dt * f1 + + def _linear_interp(self, t0, t1, y0, y1, t): + if t == t0: + return y0 + if t == t1: + return y1 + slope = (t - t0) / (t1 - t0) + return y0 + slope * (y1 - y0) diff --git a/third_party/yacs_config.py b/third_party/yacs_config.py new file mode 100644 index 0000000..0f81ff3 --- /dev/null +++ b/third_party/yacs_config.py @@ -0,0 +1,586 @@ +# --------------------------------------------------------------- +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# This file has been modified from a file in the following repo +# (released under the Apache License 2.0). +# +# Source: +# https://github.com/rbgirshick/yacs/blob/master/yacs/config.py +# +# The license for the original version of this file can be +# found in +# http://www.apache.org/licenses/LICENSE-2.0 +# The modifications +# to this file are subject to the NVIDIA Source Code License for +# LION located at the root directory. +# --------------------------------------------------------------- + +"""YACS -- Yet Another Configuration System is designed to be a simple +configuration management system for academic and industrial research +projects. + +See README.md for usage and examples. 
+""" +# this code is modified from https://github.com/rbgirshick/yacs/blob/master/yacs/config.py + +import copy +import io +# import logging + +from loguru import logger +import os +import sys +from ast import literal_eval + +import yaml + +# Flag for py2 and py3 compatibility to use when separate code paths are necessary +# When _PY2 is False, we assume Python 3 is in use +_PY2 = sys.version_info.major == 2 + +# Filename extensions for loading configs from files +_YAML_EXTS = {"", ".yaml", ".yml"} +_PY_EXTS = {".py"} + +# py2 and py3 compatibility for checking file object type +# We simply use this to infer py2 vs py3 +if _PY2: + _FILE_TYPES = (file, io.IOBase) +else: + _FILE_TYPES = (io.IOBase, ) + +# CfgNodes can only contain a limited set of valid types +_VALID_TYPES = {tuple, list, str, int, float, bool} +# py2 allow for str and unicode +if _PY2: + _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 + +# Utilities for importing modules from file paths +if _PY2: + # imp is available in both py2 and py3 for now, but is deprecated in py3 + import imp +else: + import importlib.util + +# logger = logging.getLogger(__name__) + + +class CfgNode(dict): + """ + CfgNode represents an internal node in the configuration tree. It's a simple + dict-like container that allows for attribute-based access to keys. + """ + + IMMUTABLE = "__immutable__" + DEPRECATED_KEYS = "__deprecated_keys__" + RENAMED_KEYS = "__renamed_keys__" + NEW_ALLOWED = "__new_allowed__" + + def __init__(self, init_dict=None, key_list=None, new_allowed=False): + """ + Args: + init_dict (dict): the possibly-nested dictionary to initailize the CfgNode. + key_list (list[str]): a list of names which index this CfgNode from the root. + Currently only used for logging purposes. + new_allowed (bool): whether adding new key is allowed when merging with + other configs. + """ + # Recursively convert nested dictionaries in init_dict into CfgNodes + init_dict = {} if init_dict is None else init_dict + key_list = [] if key_list is None else key_list + init_dict = self._create_config_tree_from_dict(init_dict, key_list) + super(CfgNode, self).__init__(init_dict) + # Manage if the CfgNode is frozen or not + self.__dict__[CfgNode.IMMUTABLE] = False + # Deprecated options + # If an option is removed from the code and you don't want to break existing + # yaml configs, you can add the full config key as a string to the set below. + self.__dict__[CfgNode.DEPRECATED_KEYS] = set() + # Renamed options + # If you rename a config option, record the mapping from the old name to the new + # name in the dictionary below. Optionally, if the type also changed, you can + # make the value a tuple that specifies first the renamed key and then + # instructions for how to edit the config file. + self.__dict__[CfgNode.RENAMED_KEYS] = { + # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow + # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow + # 'EXAMPLE.NEW.KEY', + # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or " + # + "'foo:bar' -> ('foo', 'bar')" + # ), + } + + # Allow new attributes after initialisation + self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed + + @classmethod + def _create_config_tree_from_dict(cls, dic, key_list): + """ + Create a configuration tree using the given dict. + Any dict-like objects inside dict will be treated as a new CfgNode. + + Args: + dic (dict): + key_list (list[str]): a list of names which index this CfgNode from the root. + Currently only used for logging purposes. 
+ """ + dic = copy.deepcopy(dic) + for k, v in dic.items(): + if isinstance(v, dict): + # Convert dict to CfgNode + dic[k] = cls(v, key_list=key_list + [k]) + else: + # Check for valid leaf type or nested CfgNode + _assert_with_logging( + _valid_type(v, allow_cfg_node=False), + "Key {} with value {} is not a valid type; valid types: {}" + .format(".".join(key_list + [str(k)]), type(v), + _VALID_TYPES), + ) + return dic + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + if self.is_frozen(): + raise AttributeError( + "Attempted to set {} to {}, but CfgNode is immutable".format( + name, value)) + + _assert_with_logging( + name not in self.__dict__, + "Invalid attempt to modify internal CfgNode state: {}".format( + name), + ) + _assert_with_logging( + _valid_type(value, allow_cfg_node=True), + "Invalid type {} for key {}; valid types = {}".format( + type(value), name, _VALID_TYPES), + ) + + self[name] = value + + def __str__(self): + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + r = "" + s = [] + for k, v in sorted(self.items()): + seperator = "\n" if isinstance(v, CfgNode) else " " + attr_str = "{}:{}{}".format(str(k), seperator, str(v)) + attr_str = _indent(attr_str, 2) + s.append(attr_str) + r += "\n".join(s) + return r + + def __repr__(self): + return "{}({})".format(self.__class__.__name__, + super(CfgNode, self).__repr__()) + + def to_dict(self, **kwargs): + """Dump to a string.""" + def convert_to_dict(cfg_node, key_list): + if not isinstance(cfg_node, CfgNode): + _assert_with_logging( + _valid_type(cfg_node), + "Key {} with value {} is not a valid type; valid types: {}" + .format(".".join(key_list), type(cfg_node), _VALID_TYPES), + ) + return cfg_node + else: + cfg_dict = dict(cfg_node) + for k, v in cfg_dict.items(): + cfg_dict[k] = convert_to_dict(v, key_list + [k]) + return cfg_dict + + self_as_dict = convert_to_dict(self, []) + return self_as_dict + + def dump(self, **kwargs): + """Dump to a string.""" + def convert_to_dict(cfg_node, key_list): + if not isinstance(cfg_node, CfgNode): + _assert_with_logging( + _valid_type(cfg_node), + "Key {} with value {} is not a valid type; valid types: {}" + .format(".".join(key_list), type(cfg_node), _VALID_TYPES), + ) + return cfg_node + else: + cfg_dict = dict(cfg_node) + for k, v in cfg_dict.items(): + cfg_dict[k] = convert_to_dict(v, key_list + [k]) + return cfg_dict + + self_as_dict = convert_to_dict(self, []) + return yaml.safe_dump(self_as_dict, **kwargs) + + def merge_from_file(self, cfg_filename): + """Load a yaml config file and merge it this CfgNode.""" + with open(cfg_filename, "r") as f: + cfg = self.load_cfg(f) + self.merge_from_other_cfg(cfg) + + def merge_from_other_cfg(self, cfg_other): + """Merge `cfg_other` into this CfgNode.""" + _merge_a_into_b(cfg_other, self, self, []) + + def merge_from_list(self, cfg_list): + """Merge config (keys, values) in a list (e.g., from command line) into + this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`. + """ + _assert_with_logging( + len(cfg_list) % 2 == 0, + "Override list has odd length: {}; it must be a list of pairs". 
+    def merge_from_list(self, cfg_list):
+        """Merge config (keys, values) in a list (e.g., from the command line) into
+        this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
+        """
+        _assert_with_logging(
+            len(cfg_list) % 2 == 0,
+            "Override list has odd length: {}; it must be a list of pairs".
+            format(cfg_list),
+        )
+        root = self
+        for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
+            if root.key_is_deprecated(full_key):
+                continue
+            if root.key_is_renamed(full_key):
+                root.raise_key_rename_error(full_key)
+            key_list = full_key.split(".")
+            d = self
+            for subkey in key_list[:-1]:
+                _assert_with_logging(subkey in d,
+                                     "Non-existent key: {}".format(full_key))
+                d = d[subkey]
+            subkey = key_list[-1]
+            _assert_with_logging(subkey in d,
+                                 "Non-existent key: {}".format(full_key))
+            value = self._decode_cfg_value(v)
+            value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey,
+                                                     full_key)
+            d[subkey] = value
+
+    def freeze(self):
+        """Make this CfgNode and all of its children immutable."""
+        self._immutable(True)
+
+    def defrost(self):
+        """Make this CfgNode and all of its children mutable."""
+        self._immutable(False)
+
+    def is_frozen(self):
+        """Return mutability."""
+        return self.__dict__[CfgNode.IMMUTABLE]
+
+    def _immutable(self, is_immutable):
+        """Set immutability to is_immutable and recursively apply the setting
+        to all nested CfgNodes.
+        """
+        self.__dict__[CfgNode.IMMUTABLE] = is_immutable
+        # Recursively set immutable state
+        for v in self.__dict__.values():
+            if isinstance(v, CfgNode):
+                v._immutable(is_immutable)
+        for v in self.values():
+            if isinstance(v, CfgNode):
+                v._immutable(is_immutable)
+
+    def clone(self):
+        """Recursively copy this CfgNode."""
+        return copy.deepcopy(self)
+
+    def register_deprecated_key(self, key):
+        """Register a key (e.g. `FOO.BAR`) as a deprecated option. When merging
+        deprecated keys, a warning is generated and the key is ignored.
+        """
+        _assert_with_logging(
+            key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
+            "key {} is already registered as a deprecated key".format(key),
+        )
+        self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
+
+    def register_renamed_key(self, old_name, new_name, message=None):
+        """Register a key as having been renamed from `old_name` to `new_name`.
+        When merging a renamed key, an exception is thrown alerting the user to
+        the fact that the key has been renamed.
+        """
+        _assert_with_logging(
+            old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
+            "key {} is already registered as a renamed cfg key".format(
+                old_name),
+        )
+        value = new_name
+        if message:
+            value = (new_name, message)
+        self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
+
+    def key_is_deprecated(self, full_key):
+        """Test if a key is deprecated."""
+        if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
+            logger.warning(
+                "Deprecated config key (ignoring): {}".format(full_key))
+            return True
+        return False
+
+    def key_is_renamed(self, full_key):
+        """Test if a key is renamed."""
+        return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
+
+    def raise_key_rename_error(self, full_key):
+        new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
+        if isinstance(new_key, tuple):
+            msg = " Note: " + new_key[1]
+            new_key = new_key[0]
+        else:
+            msg = ""
+        raise KeyError(
+            "Key {} was renamed to {}; please update your config.{}".format(
+                full_key, new_key, msg))
+
+    def is_new_allowed(self):
+        return self.__dict__[CfgNode.NEW_ALLOWED]
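+
+    # Command-line override sketch for merge_from_list above (the keys shown
+    # do appear in this repo's default_config.py, but treat this as
+    # illustrative):
+    #
+    #   cfg.merge_from_list(['trainer.epochs', '100', 'sde.ode_sample', '1'])
+    #
+    # String values pass through _decode_cfg_value, so '100' is coerced to the
+    # int 100 before being type-checked against the existing value.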
+    @classmethod
+    def load_cfg(cls, cfg_file_obj_or_str):
+        """
+        Load a cfg.
+
+        Args:
+            cfg_file_obj_or_str (str or file):
+                Supports loading from:
+                - A file object backed by a YAML file
+                - A file object backed by a Python source file that exports an attribute
+                  "cfg" that is either a dict or a CfgNode
+                - A string that can be parsed as valid YAML
+        """
+        _assert_with_logging(
+            isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str, )),
+            "Expected first argument to be of type {} or {}, but it was {}".
+            format(_FILE_TYPES, str, type(cfg_file_obj_or_str)),
+        )
+        if isinstance(cfg_file_obj_or_str, str):
+            return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
+        elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
+            return cls._load_cfg_from_file(cfg_file_obj_or_str)
+        else:
+            raise NotImplementedError(
+                "Impossible to reach here (unless there's a bug)")
+
+    @classmethod
+    def _load_cfg_from_file(cls, file_obj):
+        """Load a config from a YAML file or a Python source file."""
+        _, file_extension = os.path.splitext(file_obj.name)
+        if file_extension in _YAML_EXTS:
+            return cls._load_cfg_from_yaml_str(file_obj.read())
+        elif file_extension in _PY_EXTS:
+            return cls._load_cfg_py_source(file_obj.name)
+        else:
+            raise Exception(
+                "Attempt to load from an unsupported file type {}; "
+                "only {} are supported".format(file_obj,
+                                               _YAML_EXTS.union(_PY_EXTS)))
+
+    @classmethod
+    def _load_cfg_from_yaml_str(cls, str_obj):
+        """Load a config from a YAML string encoding."""
+        cfg_as_dict = yaml.safe_load(str_obj)
+        return cls(cfg_as_dict)
+
+    @classmethod
+    def _load_cfg_py_source(cls, filename):
+        """Load a config from a Python source file."""
+        module = _load_module_from_file("yacs.config.override", filename)
+        _assert_with_logging(
+            hasattr(module, "cfg"),
+            "Python module from file {} must have 'cfg' attr".format(filename),
+        )
+        VALID_ATTR_TYPES = {dict, CfgNode}
+        _assert_with_logging(
+            type(module.cfg) in VALID_ATTR_TYPES,
+            "Imported module 'cfg' attr must be in {} but is {} instead".
+            format(VALID_ATTR_TYPES, type(module.cfg)),
+        )
+        return cls(module.cfg)
+
+    @classmethod
+    def _decode_cfg_value(cls, value):
+        """
+        Decodes a raw config value (e.g., from a yaml config file or a command
+        line argument) into a Python object.
+
+        If the value is a dict, it will be interpreted as a new CfgNode.
+        If the value is a str, it will be evaluated as a Python literal.
+        Otherwise it is returned as-is.
+        """
+        # Configs parsed from raw yaml will contain dictionary keys that need to be
+        # converted to CfgNode objects
+        if isinstance(value, dict):
+            return cls(value)
+        # All remaining processing is only applied to strings
+        if not isinstance(value, str):
+            return value
+        # Try to interpret `value` as a:
+        # string, number, tuple, list, dict, boolean, or None
+        try:
+            value = literal_eval(value)
+        # The following two excepts allow v to pass through when it represents a
+        # string.
+        #
+        # Longer explanation:
+        # The type of v is always a string (before calling literal_eval), but
+        # sometimes it *represents* a string and other times a data structure, like
+        # a list. In the case that v represents a string, what we got back from the
+        # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
+        # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
+        # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
+        # will raise a SyntaxError.
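+        # Illustrative behaviour (these exact calls are not in the original
+        # comments):
+        #   _decode_cfg_value("0.5")     -> 0.5 (float)
+        #   _decode_cfg_value("[1, 2]")  -> [1, 2] (list)
+        #   _decode_cfg_value("foo/bar") -> SyntaxError inside literal_eval,
+        #                                   caught below, so the str is returned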
+ except ValueError: + pass + except SyntaxError: + pass + return value + + +load_cfg = (CfgNode.load_cfg + ) # keep this function in global scope for backward compatibility + + +def _valid_type(value, allow_cfg_node=False): + return (type(value) in _VALID_TYPES) or (allow_cfg_node + and isinstance(value, CfgNode)) + + +def _merge_a_into_b(a, b, root, key_list): + """Merge config dictionary a into config dictionary b, clobbering the + options in b whenever they are also specified in a. + """ + _assert_with_logging( + isinstance(a, CfgNode), + "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode), + ) + _assert_with_logging( + isinstance(b, CfgNode), + "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode), + ) + + for k, v_ in a.items(): + full_key = ".".join(key_list + [k]) + + v = copy.deepcopy(v_) + v = b._decode_cfg_value(v) + + if k in b: + v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) + # Recursively merge dicts + if isinstance(v, CfgNode): + try: + _merge_a_into_b(v, b[k], root, key_list + [k]) + except BaseException: + raise + else: + b[k] = v + elif b.is_new_allowed(): + b[k] = v + else: + if root.key_is_deprecated(full_key): + continue + elif root.key_is_renamed(full_key): + root.raise_key_rename_error(full_key) + else: + raise KeyError("Non-existent config key: {}".format(full_key)) + + +def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): + """Checks that `replacement`, which is intended to replace `original` is of + the right type. The type is correct if it matches exactly or is one of a few + cases in which the type can be easily coerced. + """ + original_type = type(original) + replacement_type = type(replacement) + + # The types must match (with some exceptions) + if replacement_type == original_type: + return replacement + + # Cast replacement from from_type to to_type if the replacement and original + # types match from_type and to_type + def conditional_cast(from_type, to_type): + if replacement_type == from_type and original_type == to_type: + logger.warning('cast {} to {}', from_type, to_type) + return True, to_type(replacement) + else: + return False, None + + # Conditionally casts + # list <-> tuple + casts = [(tuple, list), (list, tuple), (bool, int)] + # For py2: allow converting from str (bytes) to a unicode string + try: + casts.append((str, unicode)) # noqa: F821 + except Exception: + pass + + for (from_type, to_type) in casts: + converted, converted_value = conditional_cast(from_type, to_type) + if converted: + return converted_value + + # if original_type == int and replacement_type == bool: + # logger.warning( + # "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " + # "key: {}".format( + # original_type, replacement_type, original, replacement, full_key + # )) + # else: + raise ValueError( + "Type mismatch ({} vs. {}) with values ({} vs. 
{}) for config "
+        "key: {}".format(original_type, replacement_type, original,
+                         replacement, full_key))
+
+
+def _assert_with_logging(cond, msg):
+    if not cond:
+        logger.debug(msg)
+    assert cond, msg
+
+
+def _load_module_from_file(name, filename):
+    if _PY2:
+        module = imp.load_source(name, filename)
+    else:
+        spec = importlib.util.spec_from_file_location(name, filename)
+        module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(module)
+    return module
+
+
+def same_cfg(cfg_node, cfg_other):
+    def flatten_dict(dd, sep='_', pf=''):
+        return {pf+sep+k if pf else k: v for kk, vv in dd.items()
+                for k, v in flatten_dict(vv, sep, kk).items()} \
+            if isinstance(dd, dict) else {pf: dd}
+
+    node_s = flatten_dict(cfg_node)
+    other_s = flatten_dict(cfg_other)
+    k0, k1 = list(node_s.keys()), list(other_s.keys())
+    if sorted(k0) != sorted(k1):
+        print(f'[LEN]: {len(k0)} VS {len(k1)}')
+        k_diff1 = [i for i in k0 if i not in k1]
+        k_diff0 = [i for i in k1 if i not in k0]
+        print(f'[DIFF] keys: {k_diff1}; {k_diff0}')
+        assert (False), 'Diff key'
+        return False
+    for k, v in node_s.items():
+        if k == 'exp_key':
+            continue
+        if other_s[k] != v:
+            msg = f'{k}: {v}; {other_s[k]}'
+            logger.info(msg)
+            assert (False), 'Diff key value ' + msg
+            return False
+    return True
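+
+
+# Usage sketch (illustrative; the keys are made up): same_cfg flattens both
+# configs with '_' as the separator, compares the key sets, then compares
+# values while skipping the volatile 'exp_key' entry; any mismatch asserts.
+#
+#   cfg_a = CfgNode({'trainer': {'seed': 1}})
+#   cfg_b = cfg_a.clone()
+#   same_cfg(cfg_a, cfg_b)  # -> True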
diff --git a/train_dist.py b/train_dist.py
new file mode 100644
index 0000000..2786179
--- /dev/null
+++ b/train_dist.py
@@ -0,0 +1,251 @@
+# ---------------------------------------------------------------
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+# ---------------------------------------------------------------
+
+import importlib
+import argparse
+from loguru import logger
+from comet_ml import Experiment
+import torch
+import numpy as np
+import os
+import sys
+import torch.distributed as dist
+from torch.multiprocessing import Process
+from default_config import cfg as config
+from utils import exp_helper, io_helper
+from utils import utils
+
+
+@logger.catch(onerror=lambda _: sys.exit(1), reraise=False)
+def main(args, config):
+    # -- trainer -- #
+    logger.info('use trainer: {}', config.trainer.type)
+    trainer_lib = importlib.import_module(config.trainer.type)
+    Trainer = trainer_lib.Trainer
+
+    if config.set_detect_anomaly:
+        # attention: this makes things slow
+        torch.autograd.set_detect_anomaly(True)
+        logger.info(
+            '\n\n' + '!'*30 +
+            '\nWARNING: set_detect_anomaly is turned on; it can slow down the training!\n'
+            + '!'*30)
+
+    # -- common init -- #
+    comet_key = config.comet_key
+    _, writer = utils.common_init(args.global_rank,
+                                  config.trainer.seed, config.save_dir, comet_key)
+    trainer = Trainer(config, args)
+    writer.add_hparams(config.to_dict(), vars(args))
+    nparam = utils.count_parameters_in_M(trainer.model)
+    logger.info('param size = %fM ' % nparam)
+    writer.log_other('nparam', nparam)
+
+    if args.global_rank == 0:
+        trainer.set_writer(writer)
+        writer.set_model_graph('{}'.format(trainer.model), overwrite=True)
+        if len(config.bash_name) > 0 and os.path.exists(config.bash_name):
+            writer.log_asset(config.bash_name)
+        if len(config.bash_name) > 0 and os.path.exists(os.path.join(config.save_dir, config.bash_name.split('/')[-1])):
+            writer.log_asset(os.path.join(
+                config.save_dir, config.bash_name.split('/')[-1]))
+    ckpt_dir = os.path.join(config.save_dir, 'checkpoints')
+    snapshot_file = os.path.join(config.save_dir, 'checkpoints', 'snapshot')
+
+    # -- check if a previously saved ckpt exists -- #
+    if os.path.exists(ckpt_dir) and os.path.exists(snapshot_file):
+        logger.info(
+            '[Detected saved snapshot at the checkpoint dir] resume from preemption!!! ')
+        args.resume = True
+        args.pretrained = os.path.join(
+            config.save_dir, 'checkpoints', 'snapshot')
+    else:
+        logger.info('did not find checkpoint dir: {} (exist={}), or snapshot {} (exist={})',
+                    ckpt_dir, os.path.exists(ckpt_dir), snapshot_file, os.path.exists(snapshot_file))
+
+    # -- prepare -- #
+    if args.resume or args.eval_generation:
+        if args.pretrained is not None:
+            trainer.start_epoch = trainer.resume(
+                args.pretrained, eval_generation=args.eval_generation)
+        else:
+            raise NotImplementedError
+    elif args.pretrained is not None:
+        trainer.load_vae(args.pretrained)
+
+    if not args.eval_generation:
+        trainer.train_epochs()
+    else:
+        logger.info('[skip_sample]={}', args.skip_sample)
+
+        save_file = None
+        if not args.skip_nll:
+            trainer.eval_nll(trainer.step, ntest=args.ntest, save_file=True)
+            logger.info('save as : {}', save_file)
+        # vis sampled output
+        if not args.skip_sample:
+            trainer.vis_sample(num_vis=8, writer=trainer.writer,
+                               step=trainer.step, include_pred_x0=False,
+                               save_file=save_file)
+            trainer.eval_sample(trainer.step)
+        logger.info('done')
+
+    # make all nodes wait for rank 0 to finish saving the files
+    # if args.distributed:
+    #     dist.barrier()
+
+
+def get_args():
+    parser = argparse.ArgumentParser('encoder decoder examiner')
+    # experimental results
+    parser.add_argument('--exp_root', type=str, default='../exp',
+                        help='location of the results')
+    # parser.add_argument('--save', type=str, default='exp',
+    #                     help='id used for storing intermediate results')
+    # parser.add_argument('--recont_with_local_prior', type=bool, default=False,
+    #                     help='eval nll with local prior sampled from normal distribution')
+    parser.add_argument('--skip_sample', type=int, default=0,
+                        help='only eval nll, no sampling')
+    parser.add_argument('--skip_nll', type=int, default=0,
+                        help='skip eval nll')
+    # data
+    parser.add_argument('--ntest', type=str, default=None,
+                        help='number of samples in eval_nll; if None, eval the whole val set')
+    parser.add_argument('--dataset', type=str, default='cifar10',
+                        choices=['cifar10', 'celeba_64', 'celeba_256',
+                                 'imagenet_32', 'ffhq', 'lsun_bedroom_128'],
+                        help='which dataset to use')
+    parser.add_argument('--data', type=str, default='/tmp/nvae-diff/data',
+                        help='location of the data corpus')
+    # DDP.
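+    # Example launch (illustrative): two GPUs on one node, with trailing
+    # key/value pairs consumed by config.merge_from_list further below:
+    #
+    #   python train_dist.py --num_process_per_node 2 --config cfg.yml \
+    #       trainer.epochs 100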
+    parser.add_argument('--autocast_train', action='store_true', default=True,
+                        help='This flag enables FP16 in training.')
+    parser.add_argument('--autocast_eval', action='store_true', default=True,
+                        help='This flag enables FP16 in evaluation.')
+    parser.add_argument('--num_proc_node', type=int, default=1,
+                        help='The number of nodes in multi node env.')
+    parser.add_argument('--node_rank', type=int, default=0,
+                        help='The index of node.')
+    parser.add_argument('--local_rank', type=int, default=0,
+                        help='rank of process in the node')
+    parser.add_argument('--global_rank', type=int, default=0,
+                        help='rank of process among all the processes')
+    parser.add_argument('--num_process_per_node', type=int, default=1,
+                        help='number of gpus')
+    parser.add_argument('--master_address', type=str, default='127.0.0.1',
+                        help='address for master')
+    parser.add_argument('--seed', type=int, default=1,
+                        help='seed used for initialization')
+    parser.add_argument('--config', type=str,
+                        help='The configuration file.', default='none')
+    parser.add_argument("opt",
+                        help="Modify config options using the command-line",
+                        default=None,
+                        nargs=argparse.REMAINDER)
+    # Resume:
+    parser.add_argument('--resume', default=False, action='store_true')
+    parser.add_argument('--eval_generation',
+                        default=False, action='store_true')
+    parser.add_argument('--pretrained',
+                        default=None,
+                        type=str,
+                        help="Pretrained checkpoint")
+
+    args = parser.parse_args()
+
+    # update config
+    if args.eval_generation or args.resume:
+        logger.info('[pretrained]: {}', args.pretrained)
+        args.config = os.path.dirname(args.pretrained) + '/../cfg.yml'
+        config.merge_from_file(args.config)
+    elif args.config != 'none':
+        logger.info('load config: {}', args.config)
+        cur_exp_name = config.exp_name
+        cur_hash = config.hash
+
+        config.merge_from_file(args.config)
+        config.exp_name = cur_exp_name  # keep the current exp name, not the one from the loaded config
+        config.hash = cur_hash  # keep the current hash, not the one from the loaded config
+    config.merge_from_list(args.opt)
+
+    # Create log_name
+    EXP_ROOT = args.exp_root  # os.environ.get('EXP_ROOT', '../exp/')
+    if config.exp_name == '' or config.exp_name == 'none':
+        config.hash = io_helper.hash_str('%s' % config) + 'h'
+        cfg_file_name = exp_helper.get_expname(config)
+    else:
+        cfg_file_name = config.exp_name
+
+    # Currently save_dir and log_dir are the same
+    if args.eval_generation:
+        config.save_dir = config.log_dir = config.log_name = os.path.dirname(
+            args.config)
+        if config.trainer.type == 'ddim':
+            tag = 'eval_ddim'
+        else:
+            tag = 'eval'
+        cfg_file_name += f'/{tag}/'
+        config.log_name += f'/{tag}/'
+        config.save_dir += f'/{tag}/'
+        config.log_dir += f'/{tag}/'
+    else:
+        config.log_name = os.path.join(EXP_ROOT, cfg_file_name)
+        config.save_dir = os.path.join(EXP_ROOT, cfg_file_name)
+        config.log_dir = os.path.join(EXP_ROOT, cfg_file_name)
+    os.makedirs(config.log_dir, exist_ok=True)
+
+    # save config and log
+    if args.global_rank == 0 and not args.eval_generation:
+        logger.add(config.log_dir + '/train.log')
+        logger.info('EXP_ROOT: {} + exp name: {}, save dir: {}', EXP_ROOT,
+                    cfg_file_name, config.save_dir)
+        saved_cfg = os.path.join(config.log_dir, 'cfg.yml')
+        with open(saved_cfg, 'w') as file:
+            file.write(config.dump())
+        logger.info('save config at {}', saved_cfg)
+    elif args.eval_generation:
+        logger.add(config.log_dir + '/eval_gen.log')
+        logger.info('log dir: {}', config.log_dir)
+
+    return args, config
+
+
+if __name__ == '__main__':
+    args, config = get_args()
+    args.ntest = int(args.ntest) if args.ntest is not None else 
None + size = args.num_process_per_node + + if size > 1: + args.distributed = True + processes = [] + for rank in range(size): + logger.info('In Rank={}', rank) + args.local_rank = rank + global_rank = rank + args.node_rank * args.num_process_per_node + global_size = args.num_proc_node * args.num_process_per_node + args.global_size = global_size + args.global_rank = global_rank + logger.info('Node rank %d, local proc %d, global proc %d' % + (args.node_rank, rank, global_rank)) + p = Process(target=utils.init_processes, + args=(global_rank, global_size, main, args, config)) + p.start() + processes.append(p) + + for p in processes: + logger.info('join {}', args.local_rank) + p.join() + else: + # for debugging + args.distributed = False + args.global_size = 1 + utils.init_processes(0, size, main, args, config) + logger.info('should end now') + # if args.distributed: + # logger.info('destroy_process_group') + # dist.destroy_process_group() diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py new file mode 100644 index 0000000..17d38a8 --- /dev/null +++ b/trainers/base_trainer.py @@ -0,0 +1,852 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import os +import time +from abc import ABC, abstractmethod +from comet_ml import Experiment +import torch +import importlib +import numpy as np +from PIL import Image +from loguru import logger +import torchvision +import torch.distributed as dist +from utils.evaluation_metrics_fast import print_results +from utils.checker import * +from utils.vis_helper import visualize_point_clouds_3d +from utils.eval_helper import compute_score, get_ref_pt, get_ref_num +from utils import model_helper, exp_helper, data_helper +from utils.utils import infer_active_variables +from utils.data_helper import normalize_point_clouds +from utils.eval_helper import compute_NLL_metric +from utils.utils import AvgrageMeter +import clip + +class BaseTrainer(ABC): + def __init__(self, cfg, args): + self.cfg, self.args = cfg, args + self.scheduler = None + self.local_rank = args.local_rank + self.cur_epoch = 0 + self.start_epoch = 0 + self.epoch = 0 + self.step = 0 + self.writer = None + self.encoder = None + self.num_val_samples = cfg.num_val_samples + self.train_iter_kwargs = {} + self.num_points = self.cfg.data.tr_max_sample_points + self.best_eval_epoch = 0 + self.best_eval_score = -1 + self.use_grad_scalar = cfg.trainer.use_grad_scalar + device = torch.device('cuda:%d' % args.local_rank) + self.device_str = 'cuda:%d' % args.local_rank + self.t2s_input = [] + if cfg.clipforge.enable: + self.prepare_clip_model_data() + else: + self.clip_feat_list = None + + def set_writer(self, writer): + self.writer = writer + logger.info( + '\n'+'-'*10 + f'\n[url]: {self.writer.url}\n{self.cfg.save_dir}\n' + '-'*10) + + @abstractmethod + def train_iter(self, data, *args, **kwargs): + pass + + @abstractmethod + def sample(self, *args, **kwargs): + pass + + def log_val(self, val_info, writer=None, step=None, epoch=None, **kwargs): + if writer is not None: + for k, v in val_info.items(): + if step is not None: + writer.add_scalar(k, v, step) + else: + 
writer.add_scalar(k, v, epoch) + + def epoch_start(self, epoch): + pass + + def epoch_end(self, epoch, writer=None, **kwargs): + # Signal now that the epoch ends.... + if self.scheduler is not None: + self.scheduler.step(epoch=epoch) + if writer is not None: + writer.add_scalar( + 'train/opt_lr', self.scheduler.get_lr()[0], epoch) + if writer is not None: + writer.upload_meter(epoch=epoch, step=kwargs.get('step', None)) + + # --- util function -- + def save(self, save_name=None, epoch=None, step=None, appendix=None, save_dir=None, **kwargs): + d = { + 'opt': self.optimizer.state_dict(), + 'model': self.model.state_dict(), + 'epoch': epoch, + 'step': step + } + if appendix is not None: + d.update(appendix) + if self.use_grad_scalar: + d.update({'grad_scalar': self.grad_scalar.state_dict()}) + save_name = "epoch_%s_iters_%s.pt" % ( + epoch, step) if save_name is None else save_name + save_dir = self.cfg.save_dir if save_dir is None else save_dir + path = os.path.join(save_dir, "checkpoints", save_name) + os.makedirs(os.path.dirname(path), exist_ok=True) + logger.info('save model as : {}', path) + torch.save(d, path) + return path + + def filter_name(self, ckpt): + ckpt_new = {} + for k, v in ckpt.items(): + if k[:7] == 'module.': + kn = k[7:] + elif k[:13] == 'model.module.': + kn = k[13:] + else: + kn = k + ckpt_new[kn] = v + return ckpt_new + + def resume(self, path, strict=True, **kwargs): + ckpt = torch.load(path) + strict = True + model_weight = ckpt['model'] if 'model' in ckpt else ckpt['model_state'] + vae_weight = self.filter_name(model_weight) + self.model.load_state_dict(vae_weight, strict=strict) + if 'opt' in ckpt: + self.optimizer.load_state_dict(ckpt['opt']) + else: + logger.info('no optimizer found in ckpt') + + start_epoch = ckpt['epoch'] + self.epoch = start_epoch + self.cur_epoch = start_epoch + self.step = ckpt.get('step', 0) + logger.info('resume from : {}, epo={}', path, start_epoch) + if self.use_grad_scalar: + assert('grad_scalar' in ckpt), 'otherwise set it false' + self.grad_scalar.load_state_dict(ckpt['grad_scalar']) + return start_epoch + + def build_model(self): + cfg, args = self.cfg, self.args + if args.distributed: + dist.barrier() + model_lib = importlib.import_module(cfg.shapelatent.model) + model = model_lib.Model(cfg) + return model + + def build_data(self): + logger.info('start build_data') + cfg, args = self.cfg, self.args + self.args.eval_trainnll = cfg.eval_trainnll + data_lib = importlib.import_module(cfg.data.type) + loaders = data_lib.get_data_loaders(cfg.data, args) + train_loader = loaders['train_loader'] + test_loader = loaders['test_loader'] + return train_loader, test_loader + + def train_epochs(self): + """ train for number of epochs; """ + # main training loop + cfg, args = self.cfg, self.args + train_loader = self.train_loader + writer = self.writer + + if cfg.viz.log_freq <= -1: # treat as per epoch + cfg.viz.log_freq = int(- cfg.viz.log_freq * len(train_loader)) + if cfg.viz.viz_freq <= -1: + cfg.viz.viz_freq = - cfg.viz.viz_freq * len(train_loader) + + logger.info("[rank=%d] Start epoch: %d End epoch: %d, batch-size=%d | " + "Niter/epo=%d | log freq=%d, viz freq %d, val freq %d " % + (args.local_rank, + self.start_epoch, cfg.trainer.epochs, cfg.data.batch_size, + len(train_loader), + cfg.viz.log_freq, cfg.viz.viz_freq, cfg.viz.val_freq)) + tic0 = time.time() + step = 0 + if args.global_rank == 0: + tic_log = time.time() + self.num_total_iter = cfg.trainer.epochs * len(train_loader) + self.model.num_total_iter = self.num_total_iter + + for 
epoch in range(self.start_epoch, cfg.trainer.epochs):
+            self.cur_epoch = epoch
+            if args.global_rank == 0:
+                tic_epo = time.time()
+            if args.distributed:
+                train_loader.sampler.set_epoch(epoch)
+            if args.global_rank == 0 and cfg.trainer.type in ['trainers.voxel2pts', 'trainers.voxel2pts_ada'] and epoch == 0:
+                self.eval_nll(step=step)
+            epoch_loss = []
+            self.epoch_start(epoch)
+
+            # remove disabled latent variables by setting their mixing component to a small value
+            if epoch == 0 and cfg.sde.mixed_prediction and cfg.sde.drop_inactive_var:
+                raise NotImplementedError
+
+            ## -- train for one epoch -- ##
+            for bidx, data in enumerate(train_loader):
+                # let step start from 0 instead of 1
+                step = bidx + len(train_loader) * epoch
+
+                if args.global_rank == 0 and self.writer is not None:
+                    tic_iter = time.time()
+
+                # -- train for one iter -- #
+                logs_info = self.train_iter(data, step=step,
+                                            **self.train_iter_kwargs)
+
+                # -- log information within epoch -- #
+                if self.args.global_rank == 0:
+                    epoch_loss.append(logs_info['loss'])
+                if self.args.global_rank == 0 and (
+                        time.time() - tic_log > 60
+                ):  # log per min
+                    logger.info('[R%d] | E%d iter[%3d/%3d] | [Loss] %2.2f | '
+                                '[exp] %s | [step] %5d | [url] %s ' % (
+                                    args.global_rank, epoch, bidx, len(train_loader),
+                                    np.array(epoch_loss).mean(),
+                                    cfg.save_dir, step, writer.url
+                                ))
+                    tic_log = time.time()
+
+                # -- visualize rec and samples -- #
+                if step % int(cfg.viz.log_freq) == 0 and \
+                        args.global_rank == 0 and not (
+                        step == 0 and cfg.sde.ode_sample and
+                        (cfg.trainer.type == 'trainers.train_prior' or cfg.trainer.type ==
+                         'trainers.train_2prior')  # in this case, skip sampling at the first step
+                ):
+                    avg_loss = np.array(epoch_loss).mean()
+                    epoch_loss = []  # clean up the running epoch loss after logging it
+                    self.log_loss({'epo_loss': avg_loss},
+                                  writer=writer, step=step)
+                    visualize = int(cfg.viz.viz_freq) > 0 and \
+                        (step) % int(cfg.viz.viz_freq) == 0
+                    vis_recont = visualize
+                    if vis_recont:
+                        self.vis_recont(logs_info, writer, step)
+                    if visualize:
+                        self.model.eval()
+                        self.vis_sample(writer, step=step,
+                                        include_pred_x0=False)
+                        self.model.train()
+
+                # -- timer -- #
+                if args.global_rank == 0 and self.writer is not None:
+                    time_iter = time.time() - tic_iter
+                    self.writer.avg_meter('time_iter', time_iter, step=step)
+            ## -- log information after one epoch -- ##
+            if args.global_rank == 0:
+                epo_time = (time.time() - tic_epo) / 60.0  # min
+                logger.info('[R%d] | E%d iter[%3d/%3d] | [Loss] %2.2f '
+                            '| [exp] %s | [step] %5d | [url] %s | [time] %.1fm (~%dh) |'
+                            '[best] %d %.3fx1e-2 ' % (
+                                args.global_rank, epoch, bidx, len(train_loader),
+                                np.array(epoch_loss).mean(),
+                                cfg.save_dir, step, writer.url,
+                                epo_time, epo_time * (cfg.trainer.epochs - epoch) / 60,
+                                self.best_eval_epoch, self.best_eval_score*1e2
+                            ))
+                tic_log = time.time()  # reset tic_log
+
+            ## -- save model -- ##
+            if (epoch + 1) % int(cfg.viz.save_freq) == 0 and \
+                    int(cfg.viz.save_freq) > 0 and args.global_rank == 0:
+                self.save(epoch=epoch, step=step)
+            if ((time.time() - tic0) / 60 > cfg.snapshot_min) and \
+                    args.global_rank == 0:  # save a snapshot every cfg.snapshot_min minutes
+                file_name = self.save(
+                    save_name='snapshot_bak', epoch=epoch, step=step)
+                if file_name is None:
+                    file_name = os.path.join(
+                        self.cfg.save_dir, "checkpoints", "snapshot_bak")
+                os.rename(file_name, file_name.replace(
+                    'snapshot_bak', 'snapshot'))
+                tic0 = time.time()
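+
+            # (note) the snapshot is first written as 'snapshot_bak' and then
+            # renamed to 'snapshot'; on a POSIX filesystem the rename should
+            # not leave a partially written 'snapshot', which is the file the
+            # preemption-resume logic in train_dist.py looks for.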
+            ## -- run eval -- ##
+            if int(cfg.viz.val_freq) > 0 and (epoch + 1) % int(cfg.viz.val_freq) == 0 and \
+                    args.global_rank == 0:
+                eval_score = self.eval_nll(step=step, save_file=False)
+                if eval_score < self.best_eval_score or self.best_eval_score < 0:
+                    self.save(save_name='best_eval.pth',  # save_dir=snapshot_dir,
+                              epoch=epoch, step=step)
+                    self.best_eval_score = eval_score
+                    self.best_eval_epoch = epoch
+
+            ## -- Signal the trainer to clean up now that an epoch has ended -- ##
+            self.epoch_end(epoch, writer=writer, step=step)
+        ### -- end of the training -- ###
+        if args.global_rank == 0:
+            self.eval_nll(step=step)
+        if self.cfg.trainer.type == 'trainers.train_prior':  # and args.global_rank == 0:
+            self.model.eval()
+            self.eval_sample(step)
+            logger.info('debugging eval-sample; exit now')
+
+    @torch.no_grad()
+    def log_loss(self, train_info, writer=None, step=None, **kwargs):
+        """ write to tensorboard and visualize
+        """
+        if writer is None:
+            return
+
+        # Log training information to tensorboard
+        train_info = {
+            k: (v.cpu() if not isinstance(v, float) else v)
+            for k, v in train_info.items()
+        }
+        for k, v in train_info.items():
+            if not ('loss' in k):
+                continue
+            if step is not None:
+                writer.add_scalar('train/' + k, v, step)
+            else:
+                epoch = kwargs.get('epoch', None)
+                assert epoch is not None
+                writer.add_scalar('train/' + k, v, epoch)
+
+    # ----------------------------------------------- #
+    #  visualization function and sampling function   #
+    # ----------------------------------------------- #
+    @torch.no_grad()
+    def vis_recont(self, output, writer, step, normalize_pts=False):
+        """
+        Args:
+            x_0: Input point cloud, (B, N, d).
+        """
+        if writer is None:
+            return 0
+        # x_0: target
+        # x_0_pred: reconstruction
+        # x_t: intermediate sample at t (if t is not None)
+        x_0_pred, x_0, x_t = output.get('x_0_pred', None), \
+            output.get('x_0', None), output.get('x_t', None)
+        if x_0_pred is None or x_0 is None or x_t is None:
+            logger.info('x_0_pred: None? {}; x_0: None? {}, x_t: None? 
{}', + x_0_pred is None, x_0 is None, x_t is None) + return 0 + + CHECK3D(x_0) + CHECK3D(x_t) + CHECK3D(x_0_pred) + + t = output.get('t', None) + nvis = min(max(x_0.shape[0], 2), 5) + img_list = [] + for b in range(nvis): + x_list, name_list = [], [] + x_list.append(x_0_pred[b]) + name_list.append('pred') + + if t is not None and t[b] > 0: + x_t_name = 'x_t%d' % t[b].item() + name_list.append(x_t_name) + x_list.append(x_t[b]) + + x_list.append(x_0[b]) + name_list.append('target') + + for k, v in output.items(): + if 'vis/' in k: + x_list.append(v[b]) + name_list.append(k) + if normalize_pts: + x_list = normalize_point_clouds(x_list) + + vis_order = self.cfg.viz.viz_order + vis_args = {'vis_order': vis_order} + + img = visualize_point_clouds_3d(x_list, name_list, **vis_args) + img_list.append(img) + img_list = torchvision.utils.make_grid( + [torch.as_tensor(a) for a in img_list], pad_value=0) + writer.add_image('vis_out/recont-train', img_list, step) + + @torch.no_grad() + def eval_sample(self, step=0): + """ compute sample metric: MMD,COV,1-NNA """ + writer = self.writer + batch_size_test = self.cfg.data.batch_size_test + input_dim = self.cfg.ddpm.input_dim + ddim_step = self.cfg.eval_ddim_step + device = model_helper.get_device(self.model) + test_loader = self.test_loader + test_size = batch_size_test * len(test_loader) + sample_num_points = self.cfg.data.tr_max_sample_points + cates = self.cfg.data.cates + num_ref = get_ref_num( + cates) if self.cfg.num_ref == 0 else self.cfg.num_ref + + # option for post-processing + if self.cfg.data.recenter_per_shape or self.cfg.data.normalize_shape_box or self.cfg.data.normalize_range: + norm_box = True + else: + norm_box = False + logger.info('norm_box: {}, recenter : {}, shapebox: {}', + norm_box, self.cfg.data.recenter_per_shape, + self.cfg.data.normalize_shape_box) + + # get exp tag and output name + tag = exp_helper.get_evalname(self.cfg) + if not self.cfg.sde.ode_sample: + tag += 'diet' + else: + tag += 'ode%d' % self.cfg.sde.ode_sample + output_name = os.path.join( + self.cfg.save_dir, f'samples_{step}{tag}.pt') + logger.info('batch_size_test={}, test_size={}, saved output: {} ', + batch_size_test, test_size, output_name) + gen_pcs = [] + + ### ---- ref_pcs ---- # + ##ref_pcs = [] + ##m_pcs, s_pcs = [], [] + # for i, data in enumerate(test_loader): + ## tr_points = data['tr_points'] + ## m, s = data['mean'], data['std'] + # ref_pcs.append(tr_points) # B,N,3 + # m_pcs.append(m.float()) + # s_pcs.append(s.float()) + ## sample_num_points = tr_points.shape[1] + # assert(tr_points.shape[2] in [3,6] + # ), f'expect B,N,3; get {tr_points.shape}' + ##ref_pcs = torch.cat(ref_pcs, dim=0) + ##m_pcs = torch.cat(m_pcs, dim=0) + ##s_pcs = torch.cat(s_pcs, dim=0) + # if VIS: + ## img_list = [] + # for i in range(4): + ## norm_ref, norm_gen = data_helper.normalize_point_clouds([ref_pcs[i], ref_pcs[-i]]) + ## img = visualize_point_clouds_3d([norm_ref, norm_gen], [f'ref-{i}', f'ref-{-i}']) + ## img_list.append(torch.as_tensor(img) / 255.0) + ## path = output_name.replace('.pt', '_ref.png') + # torchvision.utils.save_image(img_list, path) + ## grid = torchvision.utils.make_grid(img_list) + # ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + ## writer.add_image('ref', grid, 0) + # logger.info(writer.url) + ## logger.info('save vis at {}', path) + + # ---- gen_pcs ---- # + if True: + len_test_loader = num_ref // batch_size_test + 1 + if self.args.distributed: + num_gen_iter = max(1, len_test_loader // 
self.args.global_size) + if num_gen_iter * batch_size_test * self.args.global_size < num_ref: + num_gen_iter = num_gen_iter + 1 + else: + num_gen_iter = len_test_loader + + index_start = 0 + logger.info('Rank={}, num_gen_iter: {}; num_ref={}, batch_size_test={}', + self.args.global_rank, num_gen_iter, num_ref, batch_size_test) + seed = self.cfg.trainer.seed + for i in range(0, num_gen_iter): + torch.manual_seed(seed + i) + np.random.seed(seed + i) + torch.cuda.manual_seed(seed + i) + torch.cuda.manual_seed_all(seed + i) + + logger.info('#%d/%d; BS=%d' % + (i, num_gen_iter, batch_size_test)) + # ---- draw samples ---- # + self.index_start = index_start + x = self.sample(num_shapes=batch_size_test, + num_points=sample_num_points, + device_str=device.type, + for_vis=False, + ddim_step=ddim_step).permute(0, 2, 1).contiguous() # B,3,N->B,N,3 + assert( + x.shape[-1] == input_dim), f'expect x: B,N,{input_dim}; get {x.shape}' + index_start = index_start + batch_size_test + gen_pcs.append(x.detach().cpu()) + + gen_pcs = torch.cat(gen_pcs, dim=0) + if self.args.distributed: + gen_pcs = gen_pcs.to(torch.device(self.device_str)) + logger.info('before gather: {}, rank={}', + gen_pcs.shape, self.args.global_rank) + gen_pcs_list = [torch.zeros_like(gen_pcs) + for _ in range(self.args.global_size)] + dist.all_gather(gen_pcs_list, gen_pcs) + gen_pcs = torch.cat(gen_pcs_list, dim=0).cpu() + logger.info('after gather: {}, rank={}', + gen_pcs.shape, self.args.global_rank) + logger.info('save as %s' % output_name) + if self.args.global_rank == 0: + torch.save(gen_pcs, output_name) + else: + logger.info('return for rank {}', self.args.global_rank) + return # only do eval on one gpu + if writer is not None: + img_list = [] + for i in range(1): + gen_list = [gen_pcs[k] for k in range(len(gen_pcs))][:8] + norm_ref = data_helper.normalize_point_clouds(gen_list) + img = visualize_point_clouds_3d(norm_ref, [f'gen-{k}' for k in range(len(norm_ref))] + ) + img_list.append(torch.as_tensor(img) / 255.0) + grid = torchvision.utils.make_grid(img_list) + logger.info('ndarr: {}, range: {} img list: {} ', grid.shape, + grid.max(), img_list[0].shape, img_list[0].max()) + writer.add_image('sample', grid, step) + logger.info('\n\t' + writer.url) + #logger.info('early exit') + # exit() + + shape_str = '{}: gen_pcs: {}'.format(self.cfg.save_dir, gen_pcs.shape) + logger.info(shape_str) + + ref = get_ref_pt(self.cfg.data.cates, self.cfg.data.type) + if ref is None: + logger.info('Not computing score') + return 1 + step_str = '%dk' % (step / 1000.0) + epoch_str = '%.1fk' % (self.epoch / 1000.0) + print_kwargs = {'dataset': self.cfg.data.cates, + 'hash': self.cfg.hash + tag, + 'step': step_str, + 'epoch': epoch_str+'-'+os.path.basename(ref).split('.')[0]} + self.model = self.model.cpu() + torch.cuda.empty_cache() + # -- compute the generation metric -- # + results = compute_score(output_name, ref_name=ref, + writer=writer, + batch_size_test=min( + 5, self.cfg.data.batch_size_test), + norm_box=norm_box, + **print_kwargs) + + self.model = self.model.to(device) + # ---- write to logger ---- # + writer.add_scalar('test/Coverage_CD', results['lgan_cov-CD'], step) + writer.add_scalar('test/Coverage_EMD', results['lgan_cov-EMD'], step) + writer.add_scalar('test/MMD_CD', results['lgan_mmd-CD'], step) + writer.add_scalar('test/MMD_EMD', results['lgan_mmd-EMD'], step) + writer.add_scalar('test/1NN_CD', results['1-NN-CD-acc'], step) + writer.add_scalar('test/1NN_EMD', results['1-NN-EMD-acc'], step) + writer.add_scalar('test/JSD', results['jsd'], 
step) + msg = f'step={step}' + msg += '\n[Test] MinMatDis | CD %.6f | EMD %.6f' % ( + results['lgan_mmd-CD'], results['lgan_mmd-EMD']) + msg += '\n[Test] Coverage | CD %.6f | EMD %.6f' % ( + results['lgan_cov-CD'], results['lgan_cov-EMD']) + msg += '\n[Test] 1NN-Accur | CD %.6f | EMD %.6f' % ( + results['1-NN-CD-acc'], results['1-NN-EMD-acc']) + msg += '\n[Test] JsnShnDis | %.6f ' % (results['jsd']) + + logger.info(msg) + with open(os.path.join(self.cfg.save_dir, 'eval_out.txt'), 'a') as f: + f.write(shape_str+'\n') + f.write(msg+'\n') + # self.cfg.data.cates, self.cfg.hash, step_str, epoch_str) + msg = print_results(results, **print_kwargs) + with open(os.path.join(self.cfg.save_dir, 'eval_out.txt'), 'a') as f: + f.write(msg+'\n') + logger.info('\n\t' + writer.url) + + def vis_sample(self, writer, num_vis=None, step=0, include_pred_x0=True, + save_file=None): + if num_vis is None: + num_vis = self.num_val_samples + logger.info("Sampling.. train-step=%s | N=%d" % (step, num_vis)) + tic = time.time() + # get three list with entry: [L,N,3] + # traj, traj_x0, time_list + traj, pred_x0 = self.sample(num_points=self.num_points, + num_shapes=num_vis, for_vis=True, use_ddim=True, + save_file=save_file) + + toc = time.time() + logger.info('sampling take %.1f sec' % (toc-tic)) + + # display only a few steps + num_shapes = num_vis + vis_num_steps = len(traj) + vis_index = list(traj.keys()) + vis_index = vis_index[::-1] + display_num_step = 5 + step_size = max(1, vis_num_steps // 5) + display_num_step_list = [] + for k in range(0, vis_num_steps, step_size): + display_num_step_list.append(vis_index[k]) + if self.num_steps not in display_num_step_list and self.num_steps in traj: + display_num_step_list.append(self.num_steps) + logger.info('saving vis with N={}', len(display_num_step_list)) + alltraj_list = [] + allpred_x0_list = [] + allstep_list = [] + for b in range(num_shapes): + traj_list = [] + pred_x0_list = [] + step_list = [] + for k in display_num_step_list: + v = traj[k] + traj_list.append(v[b].permute(1, 0).contiguous()) + v = pred_x0[k] + pred_x0_list.append(v[b].permute(1, 0).contiguous()) + step_list.append(k) + # B3N -> 3,N -> N,3 use first sample only + alltraj_list.append(traj_list) + allpred_x0_list.append(pred_x0_list) + allstep_list.append(step_list) + traj, traj_x0, time_list = alltraj_list, allpred_x0_list, allstep_list + + # vis the final images, + all_imgs = [] + all_imgs_torchvis = [] # no preconcat in the image, left to the torchvision + all_imgs_torchvis_norm = [] # no preconcat in the image, left to the torchvision + for idx in range(num_vis): + pcs = traj[idx][0:1] # 1,N,3 + img = [] + # vis the normalized point cloud + title_list = ['#%d normed x_%d' % (idx, 0)] + norm_pcs = data_helper.normalize_point_clouds(pcs) + img.append(visualize_point_clouds_3d(norm_pcs, title_list, + self.cfg.viz.viz_order)) + all_imgs_torchvis_norm.append(img[-1] / 255.0) + if include_pred_x0: + title_list = ['#%d p(x_0|x_%d,t)' % (idx, 0)] + img.append(visualize_point_clouds_3d(traj_x0[idx][0:1], title_list, + self.cfg.viz.viz_order)) + # concat along the height + all_imgs.append(np.concatenate(img, axis=1)) + + # concatenate along the width dimension + img = np.concatenate(all_imgs, axis=2) + writer.add_image('summary/sample', torch.as_tensor(img), step) + + path = os.path.join(self.cfg.save_dir, 'vis', 'sample%06d.png' % step) + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + + img_list = [torch.as_tensor(a) for a in all_imgs_torchvis_norm] + grid = 
torchvision.utils.make_grid(img_list)
+        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(
+            1, 2, 0).to('cpu', torch.uint8).numpy()
+        im = Image.fromarray(ndarr)
+        im.save(path)
+        logger.info('save as {}; url: {} ', path, writer.url)
+
+    def prepare_vis_data(self):
+        device = torch.device(self.device_str)
+        num_val_samples = self.num_val_samples
+        c = 0
+        val_x = []
+        val_input = []
+        val_cls = []
+        prior_cond = []
+        for val_batch in self.test_loader:
+            val_x.append(val_batch['tr_points'])
+            val_cls.append(val_batch['cate_idx'])
+            if 'input_pts' in val_batch:  # this is the input_pts to the vae model
+                val_input.append(val_batch['input_pts'])
+            if 'prior_cond' in val_batch:
+                prior_cond.append(val_batch['prior_cond'])
+            c += val_x[-1].shape[0]
+            if c >= num_val_samples:
+                break
+        self.val_x = torch.cat(val_x, dim=0)[:num_val_samples].to(device)
+        # this line may trigger an error; changing the dataset to output cate_idx as an int instead of a string fixes it
+        self.val_cls = torch.cat(val_cls, dim=0)[:num_val_samples].to(device)
+        self.prior_cond = torch.cat(prior_cond, dim=0)[:num_val_samples].to(
+            device) if len(prior_cond) else None
+        self.val_input = torch.cat(val_input, dim=0)[:num_val_samples].to(
+            device) if len(val_input) else None
+        c = 0
+        tr_x = []
+        m_x = []
+        s_x = []
+        tr_cls = []
+        logger.info('[prepare_vis_data] len of train_loader: {}',
+                    len(self.train_loader))
+        assert(len(self.train_loader) > 0), 'got a zero-length train_loader; this can happen when batch_size is larger than the number of training samples while the train loader has drop_last turned on'
+        for tr_batch in self.train_loader:
+            tr_x.append(tr_batch['tr_points'])
+            m_x.append(tr_batch['mean'])
+            s_x.append(tr_batch['std'])
+            tr_cls.append(tr_batch['cate_idx'].view(-1))
+            c += tr_x[-1].shape[0]
+            if c >= num_val_samples:
+                break
+        self.tr_cls = torch.cat(tr_cls, dim=0)[:num_val_samples].to(device)
+        self.tr_x = torch.cat(tr_x, dim=0)[:num_val_samples].to(device)
+        self.m_pcs = torch.cat(m_x, dim=0)[:num_val_samples].to(device)
+        self.s_pcs = torch.cat(s_x, dim=0)[:num_val_samples].to(device)
+        logger.info('tr_x: {}, m_pcs: {}, s_pcs: {}, val_x: {}',
+                    self.tr_x.shape, self.m_pcs.shape, self.s_pcs.shape, self.val_x.shape)
+
+        self.w_prior = torch.randn(
+            [num_val_samples, self.cfg.shapelatent.latent_dim]).to(device)
+        if self.clip_feat_list is not None:
+            self.clip_feat_test = []
+            for k in range(len(self.clip_feat_list)):
+                for i in range(num_val_samples // len(self.clip_feat_list)):
+                    self.clip_feat_test.append(self.clip_feat_list[k])
+            for i in range(num_val_samples - len(self.clip_feat_test)):
+                self.clip_feat_test.append(self.clip_feat_list[-1])
+            self.clip_feat_test = torch.stack(self.clip_feat_test, dim=0)
+            logger.info('[VIS data] clip_feat_test: {}',
+                        self.clip_feat_test.shape)
+            if self.clip_feat_test.shape[0] > num_val_samples:
+                self.clip_feat_test = self.clip_feat_test[:num_val_samples]
+        else:
+            self.clip_feat_test = None
+
+    def build_other_module(self):
+        logger.info('no other module to build')
+        pass
+
+    def swap_vae_param_if_need(self):
+        if self.cfg.ddpm.ema:
+            self.optimizer.swap_parameters_with_ema(store_params_in_ema=True)
+
+    # -- shared methods for all models with a vae component -- #
+    @torch.no_grad()
+    def eval_nll(self, step, ntest=None, save_file=False):
+        loss_dict = {}
+        cfg = self.cfg
+        self.swap_vae_param_if_need()
+        args = self.args
+        device = torch.device('cuda:%d' % args.local_rank)
+        tag = exp_helper.get_evalname(self.cfg)
+        eval_trainnll = 0
+        if eval_trainnll:
+            data_loader = self.train_loader
+            tag += 
'-train' + else: + data_loader = self.test_loader + gen_pcs, ref_pcs = [], [] + output_name = os.path.join(self.cfg.save_dir, f'recont_{step}{tag}.pt') + output_name_metric = os.path.join( + self.cfg.save_dir, f'recont_{step}{tag}_metric.pt') + shape_id_start = 0 + batch_metric_all = {} + + for vid, val_batch in enumerate(data_loader): + if vid % 30 == 1: + logger.info('eval: {}/{}', vid, len(data_loader)) + val_x = val_batch['tr_points'].to(device) + m, s = val_batch['mean'], val_batch['std'] + B, N, C = val_x.shape + m = m.view(B, 1, -1) + s = s.view(B, 1, -1) + inputs = val_batch['input_pts'].to( + device) if 'input_pts' in val_batch else None # the noisy points + model_kwargs = {} + + output = self.model.get_loss( + val_x, it=step, is_eval_nll=1, noisy_input=inputs, **model_kwargs) + for k, v in output.items(): + if 'print/' in k: + k = k.split('print/')[-1] + if k not in loss_dict: + loss_dict[k] = AvgrageMeter() + v = v.mean().item() if torch.is_tensor(v) else v + loss_dict[k].update(v) + + gen_x = output['final_pred'] + if gen_x.shape[1] > val_x.shape[1]: + tr_idxs = np.random.permutation(np.arange(gen_x.shape[1]))[ + :val_x.shape[1]] + gen_x = gen_x[:, tr_idxs] + + gen_x = gen_x.cpu() + val_x = val_x.cpu() + gen_x[:, :, :3] = gen_x[:, :, :3] * s + m + val_x[:, :, :3] = val_x[:, :, :3] * s + m + gen_pcs.append(gen_x.detach().cpu()) + ref_pcs.append(val_x.detach().cpu()) + if ntest is not None and shape_id_start >= int(ntest): + logger.info('!! reach number of test={}; has test: {}', + ntest, shape_id_start) + break + shape_id_start += B + # summarized batch-metric if any + for k, v in batch_metric_all.items(): + if len(v) == 0: + continue + v = torch.cat(v, dim=0) + logger.info('{}={}', k, v.mean()) + + gen_pcs = torch.cat(gen_pcs, dim=0) + ref_pcs = torch.cat(ref_pcs, dim=0) + + # Save + if self.writer is not None: + img_list = [] + for i in range(10): + points = gen_pcs[i] + points = normalize_point_clouds([points])[0] + img = visualize_point_clouds_3d([points], bound=1.0) + img_list.append(img) + img = np.concatenate(img_list, axis=2) + self.writer.add_image('nll/rec', torch.as_tensor(img), step) + if save_file: + logger.info('reconstruct point clouds..., output shape: {}, save as {}', + gen_pcs.shape, output_name) + torch.save(gen_pcs, output_name) + + results = compute_NLL_metric( + gen_pcs[:, :, :3], ref_pcs[:, :, :3], device, self.writer, output_name, batch_size=20, step=step) + score = 0 + for n, v in results.items(): + if 'detail' in n: + continue + if self.writer is not None: + logger.info('add: {}', n) + self.writer.add_scalar('eval/%s' % (n), v, step) + if 'CD' in n: + score = v + self.swap_vae_param_if_need() + return score + + def prepare_clip_model_data(self): + cfg = self.cfg + self.clip_model, self.clip_preprocess = clip.load(cfg.clipforge.clip_model, + device=self.device_str) + self.test_img_path = [] + if cfg.data.cates == 'chair': + input_t = [ + "an armchair in the shape of an avocado. 
an armchair imitating a avocado"] + text = clip.tokenize(input_t).to(self.device_str) + elif cfg.data.cates == 'car': + input_t = ["a ford model T", "a pickup", "an off-road vehicle"] + text = clip.tokenize(input_t).to(self.device_str) + elif cfg.data.cates == 'all': + input_t = ['a boeing', 'an f-16', 'an suv', 'a chunk', 'a limo', + 'a square chair', 'a swivel chair', 'a sniper rifle'] + text = clip.tokenize(input_t).to(self.device_str) + else: + raise NotImplementedError + if len(self.test_img_path): + self.test_img = [Image.open(t).convert("RGB") + for t in self.test_img_path] + self.test_img = [self.clip_preprocess(img).unsqueeze( + 0).to(self.device_str) for img in self.test_img] + self.test_img = torch.cat(self.test_img, dim=0) + else: + self.test_img = [] + self.t2s_input = self.test_img_path + input_t + clip_feat = [] + if len(self.test_img): + clip_feat.append( + self.clip_model.encode_image(self.test_img).float()) + clip_feat.append(self.clip_model.encode_text(text).float()) + self.clip_feat_list = torch.cat(clip_feat, dim=0) + diff --git a/trainers/common_fun.py b/trainers/common_fun.py new file mode 100644 index 0000000..ebc9822 --- /dev/null +++ b/trainers/common_fun.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +import torch.nn.functional as F +import torch.nn as nn +import numpy as np +from loguru import logger +from utils.vis_helper import visualize_point_clouds_3d +from utils.data_helper import normalize_point_clouds +from utils.checker import * + +@torch.no_grad() +def validate_inspect_noprior(model, + it, writer, + sample_num_points, num_samples, + need_sample=1, need_val=1, need_train=1, + w_prior=None, val_x=None, tr_x=None, + test_loader=None, # can be None + has_shapelatent=False, + bound=1.5, val_class_label=None, tr_class_label=None, + cfg={}): + """ visualize the samples, and recont if needed + """ + assert(has_shapelatent) + assert(w_prior is not None and val_x is not None and tr_x is not None) + z_list = [] + num_samples = w_prior.shape[0] if need_sample else 0 + num_recon = val_x.shape[0] + num_recon_val = num_recon if need_val else 0 + num_recon_train = num_recon if need_train else 0 + + assert(need_sample == 0 and need_val > 0 and need_train == 0) + if need_sample: + z_prior = model.pz(w_prior, sample_num_points) + z_list.append(z_prior) + if val_class_label is not None: + output = model.recont(val_x, class_label=val_class_label) + else: + output = model.recont(val_x) # torch.cat([val_x, tr_x])) + gen_x = output['final_pred'] + vis_order = cfg.viz.viz_order + vis_args = {'vis_order': vis_order} + + # vis the samples + if num_samples > 0: + img_list = [] + for i in range(num_samples): + points = gen_x[i] + points = normalize_point_clouds([points])[0] + img = visualize_point_clouds_3d([points], bound=bound, **vis_args) + img_list.append(img) + img = np.concatenate(img_list, axis=2) + writer.add_image('sample', torch.as_tensor(img), it) + + # vis the recont + if num_recon_val > 0: + img_list = [] + for i in range(num_recon_val): + points = gen_x[num_samples + i] + points = normalize_point_clouds([points]) # 
val_x[i], points]) + img = visualize_point_clouds_3d( + points, ['rec#%d' % i], bound=bound, **vis_args) + img_list.append(img) + gt_list = [] + for i in range(num_recon_val): + points = normalize_point_clouds([val_x[i]]) + img = visualize_point_clouds_3d( + points, ['gt#%d' % i], bound=bound, **vis_args) + gt_list.append(img) + img = np.concatenate(img_list, axis=2) + gt = np.concatenate(gt_list, axis=2) + img = np.concatenate([gt, img], axis=1) + + if 'vis/latent_pts' in output: + latent_pts = output['vis/latent_pts'] + img_list = [] + for i in range(num_recon_val): + points = latent_pts[num_samples + i] + points = normalize_point_clouds([points]) + latent = visualize_point_clouds_3d( + points, ['latent#%d' % i], bound=bound, **vis_args) + img_list.append(latent) + latent_list = np.concatenate(img_list, axis=2) + img = np.concatenate([img, latent_list], axis=1) + + writer.add_image('valrecont', torch.as_tensor(img), it) + + if num_recon_train > 0: + img_list = [] + for i in range(num_recon_train): + points = gen_x[num_samples + num_recon_val + i] + points = normalize_point_clouds([tr_x[i], points]) + img = visualize_point_clouds_3d( + points, ['ori', 'rec'], bound=bound, **vis_args) + img_list.append(img) + img = np.concatenate(img_list, axis=2) + writer.add_image('train/recont', torch.as_tensor(img), it) + + logger.info('writer: {}', writer.url) + diff --git a/trainers/common_fun_prior_train.py b/trainers/common_fun_prior_train.py new file mode 100644 index 0000000..cac11f3 --- /dev/null +++ b/trainers/common_fun_prior_train.py @@ -0,0 +1,363 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
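+# Shared helpers for the prior trainers (trainers/train_prior.py and
+# trainers/train_2prior.py): optimizer/EMA/scheduler construction, visual
+# validation during training, and sampling from the trained latent priors.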
+import torch
+import numpy as np
+import time
+from loguru import logger
+from utils.ema import EMA
+from torch import optim
+from torch.optim import Adam as FusedAdam
+from torch.cuda.amp import autocast, GradScaler
+from utils.sr_utils import SpectralNormCalculator
+from utils import utils
+from utils.vis_helper import visualize_point_clouds_3d
+from utils.diffusion_pvd import DiffusionDiscretized
+from utils.diffusion_continuous import DiffusionBase  # needed by the isinstance check in generate_samples_vada_2prior
+from utils.eval_helper import compute_NLL_metric
+from utils import model_helper, exp_helper, data_helper
+from timeit import default_timer as timer
+from utils.data_helper import normalize_point_clouds
+
+
+def init_optimizer_train_2prior(cfg, vae, dae, cond_enc=None):
+    args = cfg.sde
+    param_dict_dae = dae.parameters()
+    # optimizer for prior
+    if args.learning_rate_mlogit > 0:
+        raise NotImplementedError
+    if args.use_adamax:
+        from utils.adamax import Adamax
+        dae_optimizer = Adamax(param_dict_dae, args.learning_rate_dae,
+                               weight_decay=args.weight_decay, eps=1e-4)
+    elif args.use_adam:
+        cfgopt = cfg.trainer.opt
+        dae_optimizer = optim.Adam(param_dict_dae,
+                                   lr=args.learning_rate_dae,
+                                   betas=(cfgopt.beta1, cfgopt.beta2),
+                                   weight_decay=cfgopt.weight_decay)
+
+    else:
+        dae_optimizer = FusedAdam(param_dict_dae, args.learning_rate_dae,
+                                  weight_decay=args.weight_decay, eps=1e-4)
+    # add EMA functionality to the optimizer
+    dae_optimizer = EMA(dae_optimizer, ema_decay=args.ema_decay)
+    dae_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        dae_optimizer, float(args.epochs - args.warmup_epochs - 1),
+        eta_min=args.learning_rate_min_dae)
+
+    # optimizer for VAE
+    if args.use_adamax:
+        from utils.adamax import Adamax
+        vae_optimizer = Adamax(vae.parameters(), args.learning_rate_vae,
+                               weight_decay=args.weight_decay, eps=1e-3)
+    elif args.use_adam:
+        cfgopt = cfg.trainer.opt
+        vae_optimizer = optim.Adam(vae.parameters(),
+                                   lr=args.learning_rate_vae,
+                                   betas=(cfgopt.beta1, cfgopt.beta2),
+                                   weight_decay=cfgopt.weight_decay)
+
+    else:
+        vae_optimizer = FusedAdam(vae.parameters(), args.learning_rate_vae,
+                                  weight_decay=args.weight_decay, eps=1e-3)
+    vae_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        vae_optimizer, float(args.epochs - args.warmup_epochs - 1), eta_min=args.learning_rate_min_vae)
+    logger.info('[grad_scalar] enabled={}', args.autocast_train)
+    if not args.autocast_train:
+        grad_scalar = utils.DummyGradScalar()
+    else:
+        grad_scalar = GradScaler(2**10, enabled=True)
+
+    # create SN calculator
+    vae_sn_calculator = SpectralNormCalculator()
+    dae_sn_calculator = SpectralNormCalculator()
+    if args.train_vae:
+        # TODO: require using layer in layers/neural_operations
+        vae_sn_calculator.add_bn_layers(vae)
+    dae_sn_calculator.add_bn_layers(dae)
+    return {
+        'vae_scheduler': vae_scheduler,
+        'vae_optimizer': vae_optimizer,
+        'vae_sn_calculator': vae_sn_calculator,
+        'dae_scheduler': dae_scheduler,
+        'dae_optimizer': dae_optimizer,
+        'dae_sn_calculator': dae_sn_calculator,
+        'grad_scalar': grad_scalar
+    }
+
+
+@torch.no_grad()
+def validate_inspect(latent_shape,
+                     model, dae, diffusion, ode_sample,
+                     it, writer,
+                     sample_num_points, num_samples,
+                     autocast_train=False,
+                     need_sample=1, need_val=1, need_train=1,
+                     w_prior=None, val_x=None, tr_x=None,
+                     val_input=None,
+                     prior_cond=None,
+                     m_pcs=None, s_pcs=None,
+                     test_loader=None,  # can be None
+                     has_shapelatent=False, vis_latent_point=False,
+                     ddim_step=0, epoch=0, fun_generate_samples_vada=None, clip_feat=None,
+                     cls_emb=None, cfg={}):
+    """ visualize the samples, and the reconstructions if needed
+    Args:
+        has_shapelatent (bool): 
True when the model has shape latent + it (int): step index + num_samples: + need_* : draw samples for * or not + """ + assert(has_shapelatent) + assert(w_prior is not None and val_x is not None and tr_x is not None) + z_list = [] + num_samples = w_prior.shape[0] if need_sample else 0 + num_recon = val_x.shape[0] + num_recon_val = num_recon if need_val else 0 + num_recon_train = num_recon if need_train else 0 + kwargs = {} + if cls_emb is not None: + kwargs['cls_emb'] = cls_emb + assert(need_sample >= 0 and need_val > 0 and need_train == 0) + # draw samples + if need_sample: + # gen_x: B,N,3 + gen_x, nstep, ode_time, sample_time, output_dict = \ + fun_generate_samples_vada(latent_shape, dae, diffusion, + model, w_prior.shape[0], enable_autocast=autocast_train, + prior_cond=prior_cond, + ode_sample=ode_sample, ddim_step=ddim_step, clip_feat=clip_feat, + **kwargs) + logger.info('cast={}, sample step={}, ode_time={}, sample_time={}', + autocast_train, + nstep if ddim_step == 0 else ddim_step, + ode_time, sample_time) + gen_pcs = gen_x + else: + output_dict = {} + + rgb_as_normal = not cfg.data.has_color # if has color, rgb not as normal + vis_order = cfg.viz.viz_order + vis_args = {'rgb_as_normal': rgb_as_normal, 'vis_order': vis_order, + 'is_omap': 'omap' in cfg.data.type} + # vis the samples + if not vis_latent_point and num_samples > 0: + img_list = [] + for i in range(num_samples): + points = gen_x[i] # N,3 + points = normalize_point_clouds([points])[0] + img = visualize_point_clouds_3d([points], **vis_args) + img_list.append(img) + img = np.concatenate(img_list, axis=2) + writer.add_image('sample', torch.as_tensor(img), it) + + if vis_latent_point and num_samples > 0: + img_list = [] + eps_list = [] + prior_cond_list = [] + eps = output_dict['sampled_eps'].view( + num_samples, dae.num_points, dae.num_classes)[:, :, :cfg.ddpm.input_dim] + for i in range(num_samples): + points = gen_x[i] + points = normalize_point_clouds([points])[0] + img = visualize_point_clouds_3d([points], ['samples'], **vis_args) + img_list.append(img) + + points = eps[i] + points = normalize_point_clouds([points])[0] + img = visualize_point_clouds_3d([points], ['eps'], **vis_args) + eps_list.append(img) + if prior_cond is not None: + points = prior_cond[i] + if len(points.shape) > 2: # points shape is (1,X,Y,Z) + output_voxel_XYZ = points[0].cpu().numpy() # XYZ + coordsid = np.where(output_voxel_XYZ == 1) + coordsid = np.stack(coordsid, axis=1) # N,3 + points = torch.from_numpy(coordsid) + voxel_size = 1.0 + X, Y, Z = output_voxel_XYZ.shape + c = torch.tensor([X, Y, Z]).view(1, 3) * 0.5 + points = points - c # center at 1 + vis_points = points + bound = max(X, Y, Z)*0.5 + # logger.info('voxel_size: {}, output_voxel_XYZ: {}, bound: {}', + # voxel_size, output_voxel_XYZ.shape, bound) + + elif vis_args['is_omap']: + vis_points = points * s_pcs[i] # range before norm + bound = s_pcs[0].max().item() + voxel_size = cfg.data.voxel_size + else: + vis_points = points + voxel_size = cfg.data.voxel_size + bound = 1.5 # 2.0 + + img = visualize_point_clouds_3d([vis_points], ['cond'], + is_voxel=1, + voxel_size=voxel_size, + bound=bound, + **vis_args) + + points = normalize_point_clouds([points])[0] + ## print('points', points.shape, points.numpy().min(0), points.numpy().max(0), points[:3]) + img2 = visualize_point_clouds_3d([points], ['cond_center'], + **vis_args) + img = np.concatenate([img, img2], axis=1) + prior_cond_list.append(img) + + img = np.concatenate(img_list, axis=2) + img_eps = np.concatenate(eps_list, axis=2) + 
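+        # A note on the grid assembled here (inferred from the concatenations
+        # around this point, assuming C,H,W image tensors): axis=2 tiles the
+        # individual samples along the image width, while axis=1 stacks the
+        # rows [samples; eps; prior_cond] along the height before logging.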
prior_cond_list = np.concatenate(prior_cond_list, axis=2) if len( + prior_cond_list) else prior_cond_list + img = np.concatenate([img, img_eps], axis=1) + img = np.concatenate([img, prior_cond_list], axis=1) if len( + prior_cond_list) else img + writer.add_image('sample', torch.as_tensor(img), it) + + inputs = val_input if val_input is not None else val_x + output = model.recont(inputs) if cls_emb is None else model.recont( + inputs, cls_emb=cls_emb) + gen_x = output['final_pred'] + + # vis the recont on val set + if num_recon_val > 0: + img_list = [] + for i in range(num_recon_val): + points = gen_x[i] + points = normalize_point_clouds([points]) + img = visualize_point_clouds_3d(points, ['rec#%d' % i], **vis_args) + img_list.append(img) + gt_list = [] + for i in range(num_recon_val): + points = normalize_point_clouds([val_x[i]]) + img = visualize_point_clouds_3d(points, ['gt#%d' % i], **vis_args) + gt_list.append(img) + img = np.concatenate(img_list, axis=2) + gt = np.concatenate(gt_list, axis=2) + img = np.concatenate([gt, img], axis=1) + if val_input is not None: # also vis the input, used when we take voxel points as input + input_list = [] + for i in range(num_recon_val): + points = val_input[i] + points = normalize_point_clouds([points]) + input_img = visualize_point_clouds_3d( + points, ['input#%d' % i], **vis_args) + input_list.append(input_img) + input_list = np.concatenate(input_list, axis=2) + img = np.concatenate([img, input_list], axis=1) + writer.add_image('valrecont', torch.as_tensor(img), it) + + # vis recont on train set + if num_recon_train > 0: + img_list = [] + for i in range(num_recon_train): + points = gen_x[num_recon_val + i] + points = normalize_point_clouds([tr_x[i], points]) + img = visualize_point_clouds_3d(points, ['ori', 'rec'], **vis_args) + img_list.append(img) + img = np.concatenate(img_list, axis=2) + writer.add_image('train/recont', torch.as_tensor(img), it) + + logger.info('writer: {}', writer.url) + + return output_dict + + +@torch.no_grad() +def generate_samples_vada_2prior(shape, dae, diffusion, vae, num_samples, enable_autocast, + ode_eps=0.00001, ode_solver_tol=1e-5, ode_sample=False, + prior_var=1.0, temp=1.0, vae_temp=1.0, noise=None, + need_denoise=False, prior_cond=None, device=None, cfg=None, + ddim_step=0, clip_feat=None, cls_emb=None): + """ this function is copied from trainers/train_2prior.py + used by trainers/cond_prior.py + should also support trainers/train_2prior.py but not test yet + """ + output = {} + if ode_sample == 1: + assert isinstance( + diffusion, DiffusionBase), 'ODE-based sampling requires cont. diffusion!' + assert ode_eps is not None, 'ODE-based sampling requires integration cutoff ode_eps!' + assert ode_solver_tol is not None, 'ODE-based sampling requires ode solver tolerance!' + start = timer() + condition_input = None + eps_list = [] + for i in range(2): + assert(cls_emb is None), f' not support yet' + eps, nfe, time_ode_solve = diffusion.sample_model_ode( + dae[i], num_samples, shape[i], ode_eps, ode_solver_tol, enable_autocast, temp, noise, + condition_input=condition_input, clip_feat=clip_feat, + ) + condition_input = eps + eps_list.append(eps) + output['sampled_eps'] = eps + eps = vae.compose_eps(eps_list) # torch.cat(eps, dim=1) + elif ode_sample == 0: + assert isinstance( + diffusion, DiffusionDiscretized), 'Regular sampling requires disc. diffusion!' + assert noise is None, 'Noise is not used in ancestral sampling.' + nfe = diffusion._diffusion_steps + time_ode_solve = 999.999 # Yeah I know... 
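+        # A sketch of the ancestral sampling chain below: dae[0] first denoises
+        # the global style latent (conditioned on the voxel global embedding or
+        # cls_emb, when given); dae[1] then denoises the latent points,
+        # conditioned on the stage-0 sample mapped through vae.global2style
+        # (plus grid_emb for the voxel-conditioned model); finally
+        # vae.compose_eps recombines both latents for vae.sample.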
+ start = timer() + + dae_kwarg = {'is_image': False, 'prior_var': prior_var} + dae_kwarg['clip_feat'] = clip_feat + if cfg.data.cond_on_voxel: + output['prior_cond'] = prior_cond + voxel_grid_enc_out = dae[2](prior_cond.to( + device)) # embed the condition_input + condition_input = voxel_grid_enc_out['global_emb'] + else: + condition_input = None if cls_emb is None else cls_emb + + all_eps = [] + for i in range(2): + if i == 1 and cfg.data.cond_on_voxel: + dae_kwarg['grid_emb'] = voxel_grid_enc_out['grid_emb'] + if ddim_step > 0: + assert(cls_emb is None), f'not support yet' + eps, eps_list = diffusion.run_ddim(dae[i], + num_samples, shape[i], temp, enable_autocast, + ddim_step=ddim_step, + condition_input=condition_input, + skip_type=cfg.sde.ddim_skip_type, + kappa=cfg.sde.ddim_kappa, + dae_index=i, + **dae_kwarg) + else: + eps, eps_list = diffusion.run_denoising_diffusion(dae[i], + num_samples, shape[i], temp, enable_autocast, + condition_input=condition_input, + **dae_kwarg + ) + condition_input = eps + + if cls_emb is not None: + condition_input = torch.cat([condition_input, + cls_emb.unsqueeze(-1).unsqueeze(-1)], dim=1) + if i == 0: + condition_input = vae.global2style(condition_input) + all_eps.append(eps) + + output['sampled_eps'] = eps + eps = vae.compose_eps(all_eps) + output['eps_list'] = eps_list + output['print/sample_mean_global'] = eps.view( + num_samples, -1).mean(-1).mean() + output['print/sample_var_global'] = eps.view( + num_samples, -1).var(-1).mean() + decomposed_eps = vae.decompose_eps(eps) + image = vae.sample(num_samples=num_samples, + decomposed_eps=decomposed_eps, cls_emb=cls_emb) + + end = timer() + sampling_time = end - start + # average over GPUs + nfe_torch = torch.tensor(nfe * 1.0, device='cuda') + sampling_time_torch = torch.tensor(sampling_time * 1.0, device='cuda') + time_ode_solve_torch = torch.tensor(time_ode_solve * 1.0, device='cuda') + return image, nfe_torch, time_ode_solve_torch, sampling_time_torch, output diff --git a/trainers/hvae_trainer.py b/trainers/hvae_trainer.py new file mode 100644 index 0000000..784b374 --- /dev/null +++ b/trainers/hvae_trainer.py @@ -0,0 +1,204 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
+""" to train hierarchical VAE model +this trainer only train the vae without prior +""" +import os +import sys +import torch +import torch.nn.functional as F +import numpy as np +from loguru import logger +import torch.distributed as dist +from trainers.base_trainer import BaseTrainer +from utils.eval_helper import compute_NLL_metric +from utils import model_helper, exp_helper, data_helper +from utils.checker import * +from utils import utils +from trainers.common_fun import validate_inspect_noprior +from torch.cuda.amp import autocast, GradScaler +import third_party.pvcnn.functional as pvcnn_fn +from calmsize import size as calmsize + + +class Trainer(BaseTrainer): + def __init__(self, cfg, args): + """ + Args: + cfg: training config + args: used for distributed training + """ + super().__init__(cfg, args) + self.train_iter_kwargs = {} + self.sample_num_points = cfg.data.tr_max_sample_points + device = torch.device('cuda:%d' % args.local_rank) + self.device_str = 'cuda:%d' % args.local_rank + if not cfg.trainer.use_grad_scalar: + self.grad_scalar = utils.DummyGradScalar() + else: + logger.info('Init GradScaler!') + self.grad_scalar = GradScaler(2**10, enabled=True) + + self.model = self.build_model().to(device) + if len(self.cfg.sde.vae_checkpoint): + logger.info('Load vae_checkpoint: {}', self.cfg.sde.vae_checkpoint) + self.model.load_state_dict( + torch.load(self.cfg.sde.vae_checkpoint)['model']) + + logger.info('broadcast_params: device={}', device) + utils.broadcast_params(self.model.parameters(), + args.distributed) + self.build_other_module() + if args.distributed: + logger.info('waitting for barrier, device={}', device) + dist.barrier() + logger.info('pass barrier, device={}', device) + + self.train_loader, self.test_loader = self.build_data() + + # The optimizer + self.optimizer, self.scheduler = utils.get_opt( + self.model.parameters(), + self.cfg.trainer.opt, + cfg.ddpm.ema, self.cfg) + # Build Spectral Norm Regularization if needed + if self.cfg.trainer.sn_reg_vae: + raise NotImplementedError + + # Prepare variable for summy + self.num_points = self.cfg.data.tr_max_sample_points + logger.info('done init trainer @{}', device) + + # Prepare for evaluation + # init the latent for validate + self.prepare_vis_data() + # ------------------------------------------- # + # training fun # + # ------------------------------------------- # + + def epoch_start(self, epoch): + pass + + def epoch_end(self, epoch, writer=None, **kwargs): + return super().epoch_end(epoch, writer=writer) + + def train_iter(self, data, *args, **kwargs): + """ forward one iteration; and step optimizer + Args: + data: (dict) tr_points shape: (B,N,3) + see get_loss in models/shapelatent_diffusion.py + """ + self.model.train() + step = kwargs.get('step', None) + assert(step is not None), 'require step as input' + warmup_iters = len(self.train_loader) * \ + self.cfg.trainer.opt.vae_lr_warmup_epochs + utils.update_vae_lr(self.cfg, step, warmup_iters, self.optimizer) + if 'no_update' in kwargs: + no_update = kwargs['no_update'] + else: + no_update = False + if not no_update: + self.model.train() + self.optimizer.zero_grad() + device = torch.device(self.device_str) + tr_pts = data['tr_points'].to(device) # (B, Npoints, 3) + batch_size = tr_pts.size(0) + model_kwargs = {} + with autocast(enabled=self.cfg.sde.autocast_train): + res = self.model.get_loss(tr_pts, writer=self.writer, + it=step, **model_kwargs) + loss = res['loss'].mean() + lossv = loss.detach().cpu().item() + + if not no_update: + + 
self.grad_scalar.scale(loss).backward() + utils.average_gradients(self.model.parameters(), + self.args.distributed) + if self.cfg.trainer.opt.grad_clip > 0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), + max_norm=self.cfg.trainer.opt.grad_clip) + self.grad_scalar.step(self.optimizer) + self.grad_scalar.update() + + output = {} + if self.writer is not None: + for k, v in res.items(): + if 'print/' in k and step is not None: + v0 = v.mean().item() if torch.is_tensor(v) else v + self.writer.avg_meter(k.split('print/')[-1], v0, + step=step) + if 'hist/' in k: + output[k] = v + + output.update({ + 'loss': lossv, + 'x_0_pred': res['x_0_pred'].detach().cpu(), # perturbed data + 'x_0': res['x_0'].detach().cpu(), + 'x_t': res['final_pred'].detach().view(batch_size, -1, res['x_0'].shape[-1]), + 't': res.get('t', None) + }) + for k, v in res.items(): + if 'vis/' in k or 'msg/' in k: + output[k] = v + # if 'x_ref_pred' in res: + # output['x_ref_pred'] = res['x_ref_pred'].detach().cpu() + # if 'x_ref_pred_input' in res: + # output['x_ref_pred_input'] = res['x_ref_pred_input'].detach().cpu() + return output + # --------------------------------------------- # + # visulization function and sampling function # + # --------------------------------------------- # + + @torch.no_grad() + def vis_diffusion(self, data, writer): + pass + + def diffusion_sample(self, *args, **kwargs): + pass + + @torch.no_grad() + def vis_sample(self, writer, num_vis=None, step=0, include_pred_x0=True, + save_file=None): + bound = 1.5 if 'chair' in self.cfg.data.cates else 1.0 + assert(not self.cfg.data.cond_on_cat) + val_class_label = tr_class_label = None + validate_inspect_noprior(self.model, + step, self.writer, self.sample_num_points, + need_sample=0, need_val=1, need_train=0, + num_samples=self.num_val_samples, + test_loader=self.test_loader, + w_prior=self.w_prior, + val_x=self.val_x, tr_x=self.tr_x, + val_class_label=val_class_label, + tr_class_label=tr_class_label, + has_shapelatent=True, + bound=bound, cfg=self.cfg + ) + + @torch.no_grad() + def sample(self, num_shapes=2, num_points=2048, device_str='cuda', + for_vis=True, use_ddim=False, save_file=None, ddim_step=500): + """ return the final samples in shape [B,3,N] """ + # switch to EMA parameters + if self.cfg.ddpm.ema: + self.optimizer.swap_parameters_with_ema(store_params_in_ema=True) + self.model.eval() + + # ---- forward sampling ---- # + gen_x = self.model.sample( + num_samples=num_shapes, device_str=self.device_str) + # gen_x: BNC + CHECKEQ(gen_x.shape[2], self.cfg.ddpm.input_dim) + traj = gen_x.permute(0, 2, 1).contiguous() # BN3->B3N + + # switch back to original parameters + if self.cfg.ddpm.ema: + self.optimizer.swap_parameters_with_ema(store_params_in_ema=True) + return traj diff --git a/trainers/train_2prior.py b/trainers/train_2prior.py new file mode 100644 index 0000000..7f90018 --- /dev/null +++ b/trainers/train_2prior.py @@ -0,0 +1,449 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
+""" to train hierarchical VAE model with 2 prior +one for style latent, one for latent pts, +based on trainers/train_prior.py +""" +import os +import time +from PIL import Image +import gc +import functools +import psutil +import torch +import torch.nn.functional as F +import torch.nn as nn +import torchvision +import numpy as np +from loguru import logger +import torch.distributed as dist +from torch import optim +from utils.ema import EMA +from utils.model_helper import import_model, loss_fn +from utils.vis_helper import visualize_point_clouds_3d +from utils.eval_helper import compute_NLL_metric +from utils import model_helper, exp_helper, data_helper +from utils.data_helper import normalize_point_clouds +## from utils.diffusion_discretized import DiffusionDiscretized +from utils.diffusion_pvd import DiffusionDiscretized +from utils.diffusion_continuous import make_diffusion, DiffusionBase +from utils.checker import * +from utils import utils +from matplotlib import pyplot as plt +import third_party.pvcnn.functional as pvcnn_fn +from timeit import default_timer as timer +from torch.optim import Adam as FusedAdam +from torch.cuda.amp import autocast, GradScaler +from trainers.train_prior import Trainer as PriorTrainer +from trainers.train_prior import validate_inspect # import Trainer as PriorTrainer + +quiet = int(os.environ.get('quiet', 0)) +VIS_LATENT_PTS = 0 + + +@torch.no_grad() +def generate_samples_vada_2prior(shape, dae, diffusion, vae, num_samples, enable_autocast, + ode_eps=0.00001, ode_solver_tol=1e-5, # None, + ode_sample=False, prior_var=1.0, temp=1.0, vae_temp=1.0, noise=None, need_denoise=False, + ddim_step=0, clip_feat=None, cls_emb=None, ddim_skip_type='uniform', ddim_kappa=1.0): + output = {} + #kwargs = {} + # if cls_emb is not None: + # kwargs['cls_emb'] = cls_emb + if ode_sample == 1: + assert isinstance( + diffusion, DiffusionBase), 'ODE-based sampling requires cont. diffusion!' + assert ode_eps is not None, 'ODE-based sampling requires integration cutoff ode_eps!' + assert ode_solver_tol is not None, 'ODE-based sampling requires ode solver tolerance!' + start = timer() + condition_input = None + eps_list = [] + for i in range(len(dae)): + assert(cls_emb is None), f' not support yet' + eps, nfe, time_ode_solve = diffusion.sample_model_ode( + dae[i], num_samples, shape[i], ode_eps, ode_solver_tol, enable_autocast, temp, noise, + condition_input=condition_input, clip_feat=clip_feat, + ) + condition_input = eps + eps_list.append(eps) + output['sampled_eps'] = eps + eps = vae.compose_eps(eps_list) # torch.cat(eps, dim=1) + + elif ode_sample == 0: + assert isinstance( + diffusion, DiffusionDiscretized), 'Regular sampling requires disc. diffusion!' + assert noise is None, 'Noise is not used in ancestral sampling.' + nfe = diffusion._diffusion_steps + time_ode_solve = 999.999 # Yeah I know... 
+ start = timer() + condition_input = None if cls_emb is None else cls_emb + all_eps = [] + for i in range(len(dae)): + if ddim_step > 0: + assert(cls_emb is None), f'not support yet' + eps, eps_list = diffusion.run_ddim(dae[i], + num_samples, shape[i], temp, enable_autocast, + is_image=False, prior_var=prior_var, ddim_step=ddim_step, + condition_input=condition_input, clip_feat=clip_feat, + skip_type=ddim_skip_type, kappa=ddim_kappa) + else: + eps, eps_list = diffusion.run_denoising_diffusion(dae[i], + num_samples, shape[i], temp, enable_autocast, + is_image=False, prior_var=prior_var, + condition_input=condition_input, clip_feat=clip_feat, + ) + condition_input = eps + + if cls_emb is not None: + condition_input = torch.cat([condition_input, + cls_emb.unsqueeze(-1).unsqueeze(-1)], dim=1) + if i == 0: + condition_input = vae.global2style(condition_input) + # exit() + all_eps.append(eps) + + output['sampled_eps'] = eps + eps = vae.compose_eps(all_eps) + output['eps_list'] = eps_list + output['print/sample_mean_global'] = eps.view( + num_samples, -1).mean(-1).mean() + output['print/sample_var_global'] = eps.view( + num_samples, -1).var(-1).mean() + decomposed_eps = vae.decompose_eps(eps) + image = vae.sample(num_samples=num_samples, + decomposed_eps=decomposed_eps, cls_emb=cls_emb) + + end = timer() + sampling_time = end - start + # average over GPUs + nfe_torch = torch.tensor(nfe * 1.0, device='cuda') + sampling_time_torch = torch.tensor(sampling_time * 1.0, device='cuda') + time_ode_solve_torch = torch.tensor(time_ode_solve * 1.0, device='cuda') + return image, nfe_torch, time_ode_solve_torch, sampling_time_torch, output + + +class Trainer(PriorTrainer): + is_diffusion = 0 + + def __init__(self, cfg, args): + """ + Args: + cfg: training config + args: used for distributed training + """ + super().__init__(cfg, args) + self.fun_generate_samples_vada = functools.partial( + generate_samples_vada_2prior, ode_eps=cfg.sde.ode_eps, + ddim_skip_type=cfg.sde.ddim_skip_type, + ddim_kappa=cfg.sde.ddim_kappa) + + def compute_loss_vae(self, tr_pts, global_step, **kwargs): + """ compute forward for VAE model, used in global-only prior training + Input: + tr_pts: points + global_step: int + Returns: + output dict including entry: + 'eps': z ~ posterior + 'q_loss': 0 if not train vae else the KL+rec + 'x_0_pred': global points if not train vae + 'x_0_target': target points + + """ + vae = self.model + dae = self.dae + args = self.cfg.sde + distributed = args.distributed + vae_sn_calculator = self.vae_sn_calculator + num_total_iter = self.num_total_iter + ## diffusion = self.diffusion_cont if self.cfg.sde.ode_sample else self.diffusion_disc + if self.cfg.sde.ode_sample == 1: + diffusion = self.diffusion_cont + elif self.cfg.sde.ode_sample == 0: + diffusion = self.diffusion_disc + elif self.cfg.sde.ode_sample == 2: + raise NotImplementedError + # diffusion = [self.diffusion_cont, self.diffusion_disc] + + B = tr_pts.size(0) + with torch.set_grad_enabled(args.train_vae): + with autocast(enabled=args.autocast_train): + # posterior and likelihood + if not args.train_vae: + output = {} + all_eps, all_log_q, latent_list = vae.encode(tr_pts) + x_0_pred = x_0_target = tr_pts + vae_recon_loss = 0 + def make_4d(x): return x.unsqueeze(-1).unsqueeze(-1) if \ + len(x.shape) == 2 else x.unsqueeze(-1) + eps = make_4d(all_eps) + output.update({'eps': eps, 'q_loss': torch.zeros(1), + 'x_0_pred': tr_pts, 'x_0_target': tr_pts, + 'x_0': tr_pts, 'final_pred': tr_pts}) + else: + raise NotImplementedError + return output + # 
------------------------------------------- # + # training fun # + # ------------------------------------------- # + + def train_iter(self, data, *args, **kwargs): + """ forward one iteration; and step optimizer + Args: + data: (dict) tr_points shape: (B,N,3) + see get_loss in models/shapelatent_diffusion.py + """ + # some variables + + input_dim = self.cfg.ddpm.input_dim + loss_type = self.cfg.ddpm.loss_type + vae = self.model + dae = self.dae + dae.train() + diffusion = self.diffusion_cont if self.cfg.sde.ode_sample else self.diffusion_disc + if self.cfg.sde.ode_sample == 1: + diffusion = self.diffusion_cont + elif self.cfg.sde.ode_sample == 0: + diffusion = self.diffusion_disc + elif self.cfg.sde.ode_sample == 2: + raise NotImplementedError # not support training with different solver + ## diffusion = [self.diffusion_cont, self.diffusion_disc] + + dae_optimizer = self.dae_optimizer + vae_optimizer = self.vae_optimizer + args = self.cfg.sde + device = torch.device(self.device_str) + num_total_iter = self.num_total_iter + distributed = self.args.distributed + dae_sn_calculator = self.dae_sn_calculator + vae_sn_calculator = self.vae_sn_calculator + grad_scalar = self.grad_scalar + + global_step = step = kwargs.get('step', None) + no_update = kwargs.get('no_update', False) + + # update_lr + warmup_iters = len(self.train_loader) * args.warmup_epochs + utils.update_lr(args, global_step, warmup_iters, + dae_optimizer, vae_optimizer) + + # input + tr_pts = data['tr_points'].to(device) # (B, Npoints, 3) + inputs = data['input_pts'].to( + device) if 'input_pts' in data else None # the noisy points + tr_img = data['tr_img'].to(device) if 'tr_img' in data else None + model_kwargs = {} + if self.cfg.data.cond_on_cat: + class_label_int = data['cate_idx'].view(-1) # .to(device) + nclass = self.cfg.data.nclass + class_label = torch.nn.functional.one_hot(class_label_int, nclass) + model_kwargs['class_label'] = class_label.float().to(device) + + B = batch_size = tr_pts.size(0) + if tr_img is not None: + # tr_img: B,nimg,3,H,W + # logger.info('image: {}', tr_img.shape) + nimg = tr_img.shape[1] + tr_img = tr_img.view(B*nimg, *tr_img.shape[2:]) + clip_feat = self.clip_model.encode_image( + tr_img).view(B, nimg, -1).mean(1).float() + else: + clip_feat = None + + # optimize vae params + vae_optimizer.zero_grad() + output = self.compute_loss_vae( + tr_pts, global_step, inputs=inputs, **model_kwargs) + + # the interface between VAE and DAE is eps. 
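+        # Shape sketch (assumed from compose_eps/decompose_eps, not asserted
+        # here): eps concatenates both latents along dim 1,
+        #   eps               ~ (B, D_style + N*C, 1, 1)
+        #   decomposed_eps[0] ~ (B, D_style, 1, 1)   # global style latent
+        #   decomposed_eps[1] ~ (B, N*C, 1, 1)       # latent points, C = 3 + extra
+        # dae[0] denoises the style latent; dae[1] denoises the latent points
+        # conditioned on the style latent (after vae.global2style).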
+ eps = output['eps'].detach() # 4d: B,D,-1,1 + CHECK4D(eps) + dae_kwarg = {} + if self.cfg.data.cond_on_cat: + dae_kwarg['condition_input'] = output['cls_emb'] + # train prior + if args.train_dae: + dae_optimizer.zero_grad() + with autocast(enabled=args.autocast_train): + # get diffusion quantities for p sampling scheme and reweighting for q + t_p, var_t_p, m_t_p, obj_weight_t_p, _, g2_t_p = \ + diffusion.iw_quantities(B, args.time_eps, + args.iw_sample_p, args.iw_subvp_like_vp_sde) + # logger.info('t_p: {}, var: {}, m_t: {}', t_p[0], var_t_p[0], m_t_p[0]) + + decomposed_eps = self.vae.decompose_eps(eps) + output['vis/eps'] = decomposed_eps[1].view( + -1, self.dae.num_points, self.dae.num_classes)[:, :, :3] + p_loss_list = [] + for latent_id, eps in enumerate(decomposed_eps): + + noise_p = torch.randn(size=eps.size(), device=device) + eps_t_p = diffusion.sample_q(eps, noise_p, var_t_p, m_t_p) + # run the score model + eps_t_p.requires_grad_(True) + mixing_component = diffusion.mixing_component( + eps_t_p, var_t_p, t_p, enabled=args.mixed_prediction) + if latent_id == 0: + pred_params_p = dae[latent_id]( + eps_t_p, t_p, x0=eps, clip_feat=clip_feat, **dae_kwarg) + else: + condition_input = decomposed_eps[0] if not self.cfg.data.cond_on_cat else \ + torch.cat( + [decomposed_eps[0], output['cls_emb'].unsqueeze(-1).unsqueeze(-1)], dim=1) + condition_input = self.model.global2style( + condition_input) + pred_params_p = dae[latent_id](eps_t_p, t_p, x0=eps, + condition_input=condition_input, clip_feat=clip_feat) + + pred_eps_t0 = (eps_t_p - torch.sqrt(var_t_p) + * pred_params_p) / m_t_p + + params = utils.get_mixed_prediction(args.mixed_prediction, + pred_params_p, dae[latent_id].mixing_logit, mixing_component) + if self.cfg.latent_pts.pvd_mse_loss: + p_loss = F.mse_loss( + params.contiguous().view(B, -1), noise_p.view(B, -1), + reduction='mean') + else: + l2_term_p = torch.square(params - noise_p) + p_objective = torch.sum( + obj_weight_t_p * l2_term_p, dim=[1, 2, 3]) + regularization_p, dae_norm_loss, dae_bn_loss, dae_wdn_coeff, \ + jac_reg_loss, kin_reg_loss = utils.dae_regularization( + args, dae_sn_calculator, diffusion, dae, step, t_p, + pred_params_p, eps_t_p, var_t_p, m_t_p, g2_t_p) + reg_mlogit = ((torch.sum(torch.sigmoid(dae.mixing_logit)) - + args.regularize_mlogit_margin)**2) * args.regularize_mlogit \ + if args.regularize_mlogit else 0 + p_loss = torch.mean(p_objective) + \ + regularization_p + reg_mlogit + if self.writer is not None: + self.writer.avg_meter( + 'train/p_loss_%d' % latent_id, p_loss.detach().item()) + p_loss_list.append(p_loss) + p_loss = sum(p_loss_list) # torch.cat(p_loss_list, dim=0).sum() + loss = p_loss + # update dae parameters + grad_scalar.scale(p_loss).backward() + utils.average_gradients(dae.parameters(), distributed) + if args.grad_clip_max_norm > 0.: # apply gradient clipping + grad_scalar.unscale_(dae_optimizer) + torch.nn.utils.clip_grad_norm_(dae.parameters(), + max_norm=args.grad_clip_max_norm) + grad_scalar.step(dae_optimizer) + + # update grade scalar + grad_scalar.update() + + if args.bound_mlogit: + dae.mixing_logit.data.clamp_(max=args.bound_mlogit_value) + # Bookkeeping! 
+ writer = self.writer + if writer is not None: + writer.avg_meter('train/lr_dae', dae_optimizer.state_dict()[ + 'param_groups'][0]['lr'], global_step) + writer.avg_meter('train/lr_vae', vae_optimizer.state_dict()[ + 'param_groups'][0]['lr'], global_step) + if self.cfg.latent_pts.pvd_mse_loss: + writer.avg_meter( + 'train/p_loss', p_loss.item(), global_step) + if args.mixed_prediction and global_step % 500 == 0: + for i in range(len(dae)): + m = torch.sigmoid(dae[i].mixing_logit) + if not torch.isnan(m).any(): + writer.add_histogram( + 'mixing_prob_%d' % i, m.detach().cpu().numpy(), global_step) + + # no other loss + else: + writer.avg_meter( + 'train/p_loss', (p_loss - regularization_p).item(), global_step) + if torch.is_tensor(regularization_p): + writer.avg_meter( + 'train/reg_p', regularization_p.item(), global_step) + if args.regularize_mlogit: + writer.avg_meter( + 'train/m_logit', reg_mlogit / args.regularize_mlogit, global_step) + if args.mixed_prediction: + writer.avg_meter( + 'train/m_logit_sum', torch.sum(torch.sigmoid(dae.mixing_logit)).detach().cpu(), global_step) + if (global_step) % 500 == 0: + writer.add_scalar( + 'train/norm_loss_dae', dae_norm_loss, global_step) + writer.add_scalar('train/bn_loss_dae', + dae_bn_loss, global_step) + writer.add_scalar( + 'train/norm_coeff_dae', dae_wdn_coeff, global_step) + if args.mixed_prediction: + m = torch.sigmoid(dae.mixing_logit) + if not torch.isnan(m).any(): + writer.add_histogram( + 'mixing_prob', m.detach().cpu().numpy(), global_step) + + # write stats + if self.writer is not None: + for k, v in output.items(): + if 'print/' in k and step is not None: + self.writer.avg_meter(k.split('print/')[-1], + v.mean().item() if torch.is_tensor(v) else v, + step=step) + res = output + output_dict = { + 'loss': loss.detach().cpu().item(), + 'x_0_pred': res['x_0_pred'].detach().cpu(), # perturbed data + 'x_0': res['x_0'].detach().cpu(), + # B.B,3 + 'x_t': res['final_pred'].detach().view(batch_size, -1, res['x_0'].shape[-1]), + 't': res.get('t', None) + } + + for k, v in output.items(): + if 'vis/' in k: + output_dict[k] = v + return output_dict + # --------------------------------------------- # + # visulization function and sampling function # + # --------------------------------------------- # + + def build_prior(self): + args = self.cfg.sde + device = torch.device(self.device_str) + arch_instance_dae = utils.get_arch_cells_denoising( + 'res_ho_attn', True, False) + num_input_channels = self.cfg.shapelatent.latent_dim + + DAE = nn.ModuleList( + [ + import_model(self.cfg.latent_pts.style_prior)(args, + self.cfg.latent_pts.style_dim, self.cfg), # style prior + import_model(self.cfg.sde.prior_model)(args, + num_input_channels, self.cfg), # global prior, conditional model + ]) + + self.dae = DAE.to(device) + + # Bad solution! 
it is used in validate_inspect function + self.dae.num_points = self.dae[1].num_points + self.dae.num_classes = self.dae[1].num_classes + + if len(self.cfg.sde.dae_checkpoint): + logger.info('Load dae checkpoint: {}', + self.cfg.sde.dae_checkpoint) + checkpoint = torch.load( + self.cfg.sde.dae_checkpoint, map_location='cpu') + self.dae.load_state_dict(checkpoint['dae_state_dict']) + + self.diffusion_cont = make_diffusion(args) + self.diffusion_disc = DiffusionDiscretized( + args, self.diffusion_cont.var, self.cfg) + if not quiet: + logger.info('DAE: {}', self.dae) + logger.info('DAE: param size = %fM ' % + utils.count_parameters_in_M(self.dae)) + # sync all parameters between all gpus by sending param from rank 0 to all gpus. + utils.broadcast_params(self.dae.parameters(), self.args.distributed) diff --git a/trainers/train_prior.py b/trainers/train_prior.py new file mode 100644 index 0000000..bd74870 --- /dev/null +++ b/trainers/train_prior.py @@ -0,0 +1,741 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +""" to train hierarchical VAE model with single prior """ +import os +import time +from PIL import Image +import gc +import psutil +import functools +import torch +import torch.nn.functional as F +import torch.nn as nn +import torchvision +import numpy as np +from loguru import logger +import torch.distributed as dist +from torch import optim +from trainers.base_trainer import BaseTrainer +from utils.ema import EMA +from utils.model_helper import import_model, loss_fn +from utils.vis_helper import visualize_point_clouds_3d +from utils.eval_helper import compute_NLL_metric +from utils import model_helper, exp_helper, data_helper +from utils.data_helper import normalize_point_clouds +from utils.diffusion_pvd import DiffusionDiscretized +from utils.diffusion_continuous import make_diffusion, DiffusionBase +from utils.checker import * +from utils import utils +from matplotlib import pyplot as plt +import third_party.pvcnn.functional as pvcnn_fn +from timeit import default_timer as timer +from torch.optim import Adam as FusedAdam +from torch.cuda.amp import autocast, GradScaler +from trainers import common_fun_prior_train + + +@torch.no_grad() +def generate_samples_vada(shape, dae, diffusion, vae, num_samples, + enable_autocast, ode_eps=0.00001, ode_solver_tol=1e-5, # None, + ode_sample=False, prior_var=1.0, temp=1.0, vae_temp=1.0, + noise=None, need_denoise=False, ddim_step=0, clip_feat=None): + output = {} + if ode_sample == 1: + assert isinstance( + diffusion, DiffusionBase), 'ODE-based sampling requires cont. diffusion!' + assert ode_eps is not None, 'ODE-based sampling requires integration cutoff ode_eps!' + assert ode_solver_tol is not None, 'ODE-based sampling requires ode solver tolerance!' 
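+        # The ODE path integrates the probability-flow ODE of the continuous
+        # diffusion from t=1 down to ode_eps (a small cutoff; integrating all
+        # the way to t=0 is numerically unstable), and nfe counts the number of
+        # score-model evaluations the adaptive solver needs at tolerance
+        # ode_solver_tol.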
+ start = timer() + eps, eps_list, nfe, time_ode_solve = diffusion.sample_model_ode( + dae, num_samples, shape, ode_eps, + ode_solver_tol, enable_autocast, temp, noise, return_all_sample=True) + output['sampled_eps'] = eps + output['eps_list'] = eps_list + logger.info('ode_eps={}', ode_eps) + elif ode_sample == 0: + assert isinstance( + diffusion, DiffusionDiscretized), 'Regular sampling requires disc. diffusion!' + assert noise is None, 'Noise is not used in ancestral sampling.' + nfe = diffusion._diffusion_steps + time_ode_solve = 999.999 # Yeah I know... + start = timer() + if ddim_step > 0: + eps, eps_list = diffusion.run_ddim(dae, + num_samples, shape, temp, enable_autocast, + is_image=False, prior_var=prior_var, ddim_step=ddim_step) + else: + eps, eps_list = diffusion.run_denoising_diffusion(dae, + num_samples, shape, temp, enable_autocast, + is_image=False, prior_var=prior_var) + output['sampled_eps'] = eps # latent pts + output['eps_list'] = eps_list + else: + raise NotImplementedError + output['print/sample_mean_global'] = eps.view( + num_samples, -1).mean(-1).mean() + output['print/sample_var_global'] = eps.view( + num_samples, -1).var(-1).mean() + decomposed_eps = vae.decompose_eps(eps) + image = vae.sample(num_samples=num_samples, decomposed_eps=decomposed_eps) + + end = timer() + sampling_time = end - start + # average over GPUs + nfe_torch = torch.tensor(nfe * 1.0, device='cuda') + sampling_time_torch = torch.tensor(sampling_time * 1.0, device='cuda') + time_ode_solve_torch = torch.tensor(time_ode_solve * 1.0, device='cuda') + return image, nfe_torch, time_ode_solve_torch, sampling_time_torch, output + + +@torch.no_grad() +def validate_inspect(latent_shape, + model, dae, diffusion, ode_sample, + it, writer, + sample_num_points, num_samples, + autocast_train=False, + need_sample=1, need_val=1, need_train=1, + w_prior=None, val_x=None, tr_x=None, + val_input=None, + m_pcs=None, s_pcs=None, + test_loader=None, # can be None + has_shapelatent=False, vis_latent_point=False, + ddim_step=0, epoch=0, fun_generate_samples_vada=None, clip_feat=None, + cls_emb=None, cfg={}): + """ visualize the samples, and recont if needed + Args: + has_shapelatent (bool): True when the model has shape latent + it (int): step index + num_samples: + need_* : draw samples for * or not + """ + assert(has_shapelatent) + assert(w_prior is not None and val_x is not None and tr_x is not None) + z_list = [] + num_samples = w_prior.shape[0] if need_sample else 0 + num_recon = val_x.shape[0] + num_recon_val = num_recon if need_val else 0 + num_recon_train = num_recon if need_train else 0 + kwargs = {} + assert(need_sample >= 0 and need_val > 0 and need_train == 0) + if need_sample: + # gen_x: B,N,3 + gen_x, nstep, ode_time, sample_time, output_dict = \ + fun_generate_samples_vada(latent_shape, dae, diffusion, + model, w_prior.shape[0], enable_autocast=autocast_train, + ode_sample=ode_sample, ddim_step=ddim_step, clip_feat=clip_feat, + **kwargs) + logger.info('cast={}, sample step={}, ode_time={}, sample_time={}', + autocast_train, + nstep if ddim_step == 0 else ddim_step, + ode_time, sample_time) + gen_pcs = gen_x + else: + output_dict = {} + vis_order = cfg.viz.viz_order + vis_args = {'vis_order': vis_order, + } + # vis the samples + if not vis_latent_point and num_samples > 0: + img_list = [] + for i in range(num_samples): + points = gen_x[i] # N,3 + points = normalize_point_clouds([points])[0] + img = visualize_point_clouds_3d([points], **vis_args) + img_list.append(img) + img = np.concatenate(img_list, 
axis=2) + writer.add_image('sample', torch.as_tensor(img), it) + + # vis the latent points + if vis_latent_point and num_samples > 0: + img_list = [] + eps_list = [] + eps = output_dict['sampled_eps'].view( + num_samples, dae.num_points, dae.num_classes)[:, :, :3] + for i in range(num_samples): + points = gen_x[i] + points = normalize_point_clouds([points])[0] + img = visualize_point_clouds_3d([points], **vis_args) + img_list.append(img) + + points = eps[i] + points = normalize_point_clouds([points])[0] + img = visualize_point_clouds_3d([points], **vis_args) + eps_list.append(img) + img = np.concatenate(img_list, axis=2) + img_eps = np.concatenate(eps_list, axis=2) + img = np.concatenate([img, img_eps], axis=1) + writer.add_image('sample', torch.as_tensor(img), it) + logger.info('call recont') + inputs = val_input if val_input is not None else val_x + output = model.recont(inputs) if cls_emb is None else model.recont( + inputs, cls_emb=cls_emb) + gen_x = output['final_pred'] + + # vis the recont + if num_recon_val > 0: + img_list = [] + for i in range(num_recon_val): + points = gen_x[i] + points = normalize_point_clouds([points]) + img = visualize_point_clouds_3d(points, ['rec#%d' % i], **vis_args) + img_list.append(img) + gt_list = [] + for i in range(num_recon_val): + points = normalize_point_clouds([val_x[i]]) + img = visualize_point_clouds_3d(points, ['gt#%d' % i], **vis_args) + gt_list.append(img) + img = np.concatenate(img_list, axis=2) + gt = np.concatenate(gt_list, axis=2) + img = np.concatenate([gt, img], axis=1) + + if 'vis/latent_pts' in output: + # also vis the input, used when we take voxel points as input + input_list = [] + for i in range(num_recon_val): + points = output['vis/latent_pts'][i, :, :3] + points = normalize_point_clouds([points]) + input_img = visualize_point_clouds_3d( + points, ['input#%d' % i], **vis_args) + input_list.append(input_img) + input_list = np.concatenate(input_list, axis=2) + img = np.concatenate([img, input_list], axis=1) + writer.add_image('valrecont', torch.as_tensor(img), it) + + if num_recon_train > 0: + img_list = [] + for i in range(num_recon_train): + points = gen_x[num_recon_val + i] + points = normalize_point_clouds([tr_x[i], points]) + img = visualize_point_clouds_3d(points, ['ori', 'rec'], **vis_args) + img_list.append(img) + img = np.concatenate(img_list, axis=2) + writer.add_image('train/recont', torch.as_tensor(img), it) + + logger.info('writer: {}', writer.url) + return output_dict + + +class Trainer(BaseTrainer): + is_diffusion = 0 + + def __init__(self, cfg, args): + """ + Args: + cfg: training config + args: used for distributed training + """ + super().__init__(cfg, args) + self.draw_sample_when_vis = 1 + self.fun_generate_samples_vada = functools.partial( + generate_samples_vada, ode_eps=cfg.sde.ode_eps) + self.train_iter_kwargs = {} + self.cfg.sde.distributed = args.distributed + self.sample_num_points = cfg.data.tr_max_sample_points + self.model_var_type = cfg.ddpm.model_var_type + self.clip_denoised = cfg.ddpm.clip_denoised + self.num_steps = cfg.ddpm.num_steps + self.model_mean_type = cfg.ddpm.model_mean_type + self.loss_type = cfg.ddpm.loss_type + device = torch.device(self.device_str) + + self.model = self.build_model().to(device) + if len(self.cfg.sde.vae_checkpoint) and not args.pretrained and self.cfg.sde.vae_checkpoint != 'none': + # if has pretrained ckpt, we dont need to load the vae ckpt anymore + logger.info('Load vae_checkpoint: {}', self.cfg.sde.vae_checkpoint) + vae_ckpt = 
torch.load(self.cfg.sde.vae_checkpoint)
+            vae_weight = vae_ckpt['model']
+            self.model.load_state_dict(vae_weight)
+
+        if self.cfg.shapelatent.model == 'models.hvae_ddpm':
+            self.model.build_other_module(device)
+        logger.info('broadcast_params: device={}', device)
+        utils.broadcast_params(self.model.parameters(),
+                               args.distributed)
+        self.build_other_module()
+        self.build_prior()
+
+        if args.distributed:
+            logger.info('waiting for barrier, device={}', device)
+            dist.barrier()
+            logger.info('pass barrier, device={}', device)
+
+        self.train_loader, self.test_loader = self.build_data()
+        # The optimizer
+        self.init_optimizer()
+        # Prepare variables for the summary
+        self.num_points = self.cfg.data.tr_max_sample_points
+        logger.info('done init trainer @{}', device)
+
+        # Prepare for evaluation
+        # init the latent for validate
+        self.prepare_vis_data()
+        self.alpha_i = utils.kl_balancer_coeff(
+            num_scales=2,
+            groups_per_scale=[1, 1], fun='square')
+
+    @property
+    def vae(self):
+        return self.model
+
+    def init_optimizer(self):
+        out_dict = common_fun_prior_train.init_optimizer_train_2prior(
+            self.cfg, self.vae, self.dae)
+        self.dae_sn_calculator, self.vae_sn_calculator = out_dict[
+            'dae_sn_calculator'], out_dict['vae_sn_calculator']
+        self.vae_scheduler, self.vae_optimizer = out_dict['vae_scheduler'], out_dict['vae_optimizer']
+        self.dae_scheduler, self.dae_optimizer = out_dict['dae_scheduler'], out_dict['dae_optimizer']
+        self.grad_scalar = out_dict['grad_scalar']
+
+    def resume(self, path, strict=True, **kwargs):
+        dae, vae = self.dae, self.vae
+        vae_optimizer, vae_scheduler, dae_optimizer, dae_scheduler = \
+            self.vae_optimizer, self.vae_scheduler, self.dae_optimizer, self.dae_scheduler
+        grad_scalar = self.grad_scalar
+
+        checkpoint = torch.load(path, map_location='cpu')
+        init_epoch = checkpoint['epoch']
+        epoch = init_epoch
+        # load dae
+        dae.load_state_dict(checkpoint['dae_state_dict'])
+        dae = dae.cuda()
+        dae_optimizer.load_state_dict(checkpoint['dae_optimizer'])
+        dae_scheduler.load_state_dict(checkpoint['dae_scheduler'])
+        # load vae
+        if self.cfg.eval.load_other_vae_ckpt:
+            raise NotImplementedError
+        else:
+            vae.load_state_dict(checkpoint['vae_state_dict'])
+            vae_optimizer.load_state_dict(checkpoint['vae_optimizer'])
+        vae = vae.cuda()
+
+        # need to comment this out when loading a regular vae from the voxel2input_ada trainer
+        vae_scheduler.load_state_dict(checkpoint['vae_scheduler'])
+        grad_scalar.load_state_dict(checkpoint['grad_scalar'])
+        global_step = checkpoint['global_step']
+        ## logger.info('loaded the model at epoch %d.'%init_epoch)
+
+        start_epoch = epoch
+        self.epoch = start_epoch
+        self.step = global_step
+        logger.info('resumed from: {}, epoch={}', path, start_epoch)
+        return start_epoch
+
+    def save(self, save_name=None, epoch=None, step=None, appendix=None, save_dir=None, **kwargs):
+        dae, vae = self.dae, self.vae
+        vae_optimizer, vae_scheduler, dae_optimizer, dae_scheduler = \
+            self.vae_optimizer, self.vae_scheduler, self.dae_optimizer, self.dae_scheduler
+        grad_scalar = self.grad_scalar
+        content = {'epoch': epoch + 1, 'global_step': step,
+                   # 'args': self.cfg.sde, 'cfg': self.cfg,
+                   'grad_scalar': grad_scalar.state_dict(),
+                   'dae_state_dict': dae.state_dict(), 'dae_optimizer': dae_optimizer.state_dict(),
+                   'dae_scheduler': dae_scheduler.state_dict(), 'vae_state_dict': vae.state_dict(),
+                   'vae_optimizer': vae_optimizer.state_dict(), 'vae_scheduler': vae_scheduler.state_dict()}
+        if appendix is not None:
+            content.update(appendix)
+        save_name = "epoch_%s_iters_%s.pt" % (
+            epoch, step)
if save_name is None else save_name + if save_dir is None: + save_dir = self.cfg.save_dir + path = os.path.join(save_dir, "checkpoints", save_name) + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + logger.info('save model as : {}', path) + torch.save(content, path) + return path + + def epoch_start(self, epoch): + if epoch > self.cfg.sde.warmup_epochs: + self.dae_scheduler.step() + self.vae_scheduler.step() + + def compute_loss_vae(self, tr_pts, global_step, **kwargs): + """ compute forward for VAE model, used in global-only prior training + Input: + tr_pts: points + global_step: int + Returns: + output dict including entry: + 'eps': z ~ posterior + 'q_loss': 0 if not train vae else the KL+rec + 'x_0_pred': global points if not train vae + 'x_0_target': target points + + """ + vae = self.model + dae = self.dae + args = self.cfg.sde + distributed = args.distributed + vae_sn_calculator = self.vae_sn_calculator + num_total_iter = self.num_total_iter + if self.cfg.sde.ode_sample == 1: + diffusion = self.diffusion_cont + elif self.cfg.sde.ode_sample == 0: + diffusion = self.diffusion_disc + elif self.cfg.sde.ode_sample == 2: + raise NotImplementedError + ## diffusion = [self.diffusion_cont, self.diffusion_disc] + + ## diffusion = self.diffusion_cont if self.cfg.sde.ode_sample else self.diffusion_disc + B = tr_pts.size(0) + with torch.set_grad_enabled(args.train_vae): + with autocast(enabled=args.autocast_train): + # posterior and likelihood + if not args.train_vae: + dist = vae.encode(tr_pts) + eps = dist.sample()[0] # B,D or B,N,D or BN,D + all_log_q = [dist.log_p(eps)] + x_0_pred = x_0_target = tr_pts + vae_recon_loss = 0 + + def make_4d( + x): return x.unsqueeze(-1).unsqueeze(-1) if len(x.shape) == 2 else x.unsqueeze(-1) + eps = make_4d(eps) + output = {'eps': eps, 'q_loss': torch.zeros(1), + 'x_0_pred': tr_pts, 'x_0_target': tr_pts, + 'x_0': tr_pts, 'final_pred': tr_pts} + else: + raise NotImplementedError + return output + # ------------------------------------------- # + # training fun # + # ------------------------------------------- # + + def train_iter(self, data, *args, **kwargs): + """ forward one iteration; and step optimizer + Args: + data: (dict) tr_points shape: (B,N,3) + see get_loss in models/shapelatent_diffusion.py + """ + # some variables + + input_dim = self.cfg.ddpm.input_dim + loss_type = self.cfg.ddpm.loss_type + vae = self.model + dae = self.dae + dae.train() + diffusion = self.diffusion_cont if self.cfg.sde.ode_sample else self.diffusion_disc + dae_optimizer = self.dae_optimizer + vae_optimizer = self.vae_optimizer + args = self.cfg.sde + device = torch.device(self.device_str) + num_total_iter = self.num_total_iter + distributed = self.args.distributed + dae_sn_calculator = self.dae_sn_calculator + vae_sn_calculator = self.vae_sn_calculator + grad_scalar = self.grad_scalar + + global_step = step = kwargs.get('step', None) + no_update = kwargs.get('no_update', False) + + # update_lr + warmup_iters = len(self.train_loader) * args.warmup_epochs + utils.update_lr(args, global_step, warmup_iters, + dae_optimizer, vae_optimizer) + + # input + tr_pts = data['tr_points'].to(device) # (B, Npoints, 3) + # the noisy points, used in trainers/voxel2pts.py and trainers/voxel2pts_ada.py + inputs = data['input_pts'].to(device) if 'input_pts' in data else None + B = batch_size = tr_pts.size(0) + + # optimize vae params + vae_optimizer.zero_grad() + output = self.compute_loss_vae(tr_pts, global_step, inputs=inputs) + + # backpropagate q_loss for vae 
and update vae params, if trained + if args.train_vae: + q_loss = output['q_loss'] + loss = q_loss + grad_scalar.scale(q_loss).backward() + utils.average_gradients(vae.parameters(), distributed) + if args.grad_clip_max_norm > 0.: # apply gradient clipping + grad_scalar.unscale_(vae_optimizer) + torch.nn.utils.clip_grad_norm_(vae.parameters(), + max_norm=args.grad_clip_max_norm) + grad_scalar.step(vae_optimizer) + + # train prior + if args.train_dae: + # the interface between VAE and DAE is eps. + eps = output['eps'].detach() # 4d: B,D,-1,1 + CHECK4D(eps) + dae_optimizer.zero_grad() + with autocast(enabled=args.autocast_train): + noise_p = torch.randn(size=eps.size(), device=device) + # get diffusion quantities for p sampling scheme and reweighting for q + t_p, var_t_p, m_t_p, obj_weight_t_p, _, g2_t_p = \ + diffusion.iw_quantities(B, args.time_eps, + args.iw_sample_p, args.iw_subvp_like_vp_sde) + # logger.info('t_p: {}, var: {}, m_t: {}', t_p[0], var_t_p[0], m_t_p[0]) + eps_t_p = diffusion.sample_q(eps, noise_p, var_t_p, m_t_p) + # run the score model + eps_t_p.requires_grad_(True) + mixing_component = diffusion.mixing_component( + eps_t_p, var_t_p, t_p, enabled=args.mixed_prediction) + pred_params_p = dae(eps_t_p, t_p, x0=eps) + + # pred_eps_t0 = (eps_t_p - torch.sqrt(var_t_p) * noise_p) / m_t_p # this will recover the true eps + pred_eps_t0 = (eps_t_p - torch.sqrt(var_t_p) + * pred_params_p) / m_t_p + params = utils.get_mixed_prediction(args.mixed_prediction, + pred_params_p, dae.mixing_logit, mixing_component) + if self.cfg.latent_pts.pvd_mse_loss: + p_loss = F.mse_loss( + params.contiguous().view(B, -1), noise_p.view(B, -1), + reduction='mean') + else: + l2_term_p = torch.square(params - noise_p) + p_objective = torch.sum( + obj_weight_t_p * l2_term_p, dim=[1, 2, 3]) + + regularization_p, dae_norm_loss, dae_bn_loss, dae_wdn_coeff, \ + jac_reg_loss, kin_reg_loss = utils.dae_regularization( + args, dae_sn_calculator, diffusion, dae, step, t_p, + pred_params_p, eps_t_p, var_t_p, m_t_p, g2_t_p) + if args.regularize_mlogit: + reg_mlogit = ((torch.sum(torch.sigmoid(dae.mixing_logit)) - + args.regularize_mlogit_margin)**2) * args.regularize_mlogit + else: + reg_mlogit = 0 + p_loss = torch.mean(p_objective) + \ + regularization_p + reg_mlogit + + loss = p_loss + # update dae parameters + grad_scalar.scale(p_loss).backward() + utils.average_gradients(dae.parameters(), distributed) + if args.grad_clip_max_norm > 0.: # apply gradient clipping + grad_scalar.unscale_(dae_optimizer) + torch.nn.utils.clip_grad_norm_(dae.parameters(), + max_norm=args.grad_clip_max_norm) + grad_scalar.step(dae_optimizer) + + # update grade scalar + grad_scalar.update() + + if args.bound_mlogit: + dae.mixing_logit.data.clamp_(max=args.bound_mlogit_value) + # Bookkeeping! 
+ writer = self.writer + if writer is not None: + writer.avg_meter('train/lr_dae', dae_optimizer.state_dict()[ + 'param_groups'][0]['lr'], global_step) + writer.avg_meter('train/lr_vae', vae_optimizer.state_dict()[ + 'param_groups'][0]['lr'], global_step) + if self.cfg.latent_pts.pvd_mse_loss: + writer.avg_meter( + 'train/p_loss', p_loss.item(), global_step) + if args.mixed_prediction and global_step % 500 == 0: + m = torch.sigmoid(dae.mixing_logit) + if not torch.isnan(m).any(): + writer.add_histogram( + 'mixing_prob', m.detach().cpu().numpy(), global_step) + + # no other loss + else: + writer.avg_meter( + 'train/p_loss', (p_loss - regularization_p).item(), global_step) + if torch.is_tensor(regularization_p): + writer.avg_meter( + 'train/reg_p', regularization_p.item(), global_step) + if args.regularize_mlogit: + writer.avg_meter( + 'train/m_logit', reg_mlogit / args.regularize_mlogit, global_step) + if args.mixed_prediction: + writer.avg_meter( + 'train/m_logit_sum', torch.sum(torch.sigmoid(dae.mixing_logit)).detach().cpu(), global_step) + if (global_step) % 500 == 0: + writer.add_scalar( + 'train/norm_loss_dae', dae_norm_loss, global_step) + writer.add_scalar('train/bn_loss_dae', + dae_bn_loss, global_step) + writer.add_scalar( + 'train/norm_coeff_dae', dae_wdn_coeff, global_step) + if args.mixed_prediction: + m = torch.sigmoid(dae.mixing_logit) + if not torch.isnan(m).any(): + writer.add_histogram( + 'mixing_prob', m.detach().cpu().numpy(), global_step) + + # write stats + if self.writer is not None: + for k, v in output.items(): + if 'print/' in k and step is not None: + self.writer.avg_meter(k.split('print/')[-1], + v.mean().item() if torch.is_tensor(v) else v, + step=step) + if 'hist/' in k: + output[k] = v + res = output + output_dict = { + 'loss': loss.detach().cpu().item(), + 'x_0_pred': res['x_0_pred'].detach().cpu(), # perturbed data + 'x_0': res['x_0'].detach().cpu(), + # B.B,3 + 'x_t': res['final_pred'].detach().view(batch_size, -1, res['x_0'].shape[-1]), + 't': res.get('t', None) + } + + for k, v in output.items(): + if 'vis/' in k or 'msg/' in k: + output_dict[k] = v + return output_dict + # --------------------------------------------- # + # visulization function and sampling function # + # --------------------------------------------- # + + @torch.no_grad() + def vis_sample(self, writer, num_vis=None, step=0, include_pred_x0=True, + save_file=None): + if self.cfg.ddpm.ema: + self.swap_vae_param_if_need() + self.dae_optimizer.swap_parameters_with_ema( + store_params_in_ema=True) + shape = self.model.latent_shape() + logger.info('Latent shape for prior: {}; num_val_samples: {}', + shape, self.num_val_samples) + # [self.vae.latent_dim, .num_input_channels, dae.input_size, dae.input_size] + ode_sample = self.cfg.sde.ode_sample + ## diffusion = self.diffusion_cont if ode_sample else self.diffusion_disc + if self.cfg.sde.ode_sample == 1: + diffusion = self.diffusion_cont + elif self.cfg.sde.ode_sample == 0: + diffusion = self.diffusion_disc + if self.cfg.clipforge.enable: + assert(self.clip_feat_test is not None) + kwargs = {} + output = validate_inspect(shape, self.model, self.dae, + diffusion, ode_sample, + step, self.writer, self.sample_num_points, + epoch=self.cur_epoch, + autocast_train=self.cfg.sde.autocast_train, + need_sample=self.draw_sample_when_vis, + need_val=1, need_train=0, + num_samples=self.num_val_samples, + test_loader=self.test_loader, + w_prior=self.w_prior, + val_x=self.val_x, tr_x=self.tr_x, + val_input=self.val_input, + m_pcs=self.m_pcs, s_pcs=self.s_pcs, 
+                                  has_shapelatent=True,
+                                  vis_latent_point=self.cfg.vis_latent_point,
+                                  ddim_step=self.cfg.viz.vis_sample_ddim_step,
+                                  clip_feat=self.clip_feat_test,
+                                  cfg=self.cfg,
+                                  fun_generate_samples_vada=self.fun_generate_samples_vada,
+                                  **kwargs
+                                  )
+        if writer is not None:
+            for n, v in output.items():
+                if 'print/' not in n:
+                    continue
+                self.writer.add_scalar('%s' % (n.split('print/')[-1]), v, step)
+
+        if self.cfg.ddpm.ema:
+            self.swap_vae_param_if_need()
+            self.dae_optimizer.swap_parameters_with_ema(
+                store_params_in_ema=True)
+
+    @torch.no_grad()
+    def sample(self, num_shapes=2, num_points=2048, device_str='cuda',
+               for_vis=True, use_ddim=False, save_file=None, ddim_step=0, clip_feat=None):
+        """ return the final samples in shape [B,3,N] """
+        assert(
+            not self.cfg.clipforge.enable), 'not supported yet; not sure what the clip feat would be'
+        cfg = self.cfg
+        # switch to EMA parameters
+        if cfg.ddpm.ema:
+            self.swap_vae_param_if_need()
+            self.dae_optimizer.swap_parameters_with_ema(
+                store_params_in_ema=True)
+        self.model.eval()  # draw samples in eval mode
+        S = self.num_steps
+        logger.info('num_shapes={}, num_points={}, use_ddim={}, Nstep={}',
+                    num_shapes, num_points, use_ddim, S)
+        latent_shape = self.model.latent_shape()
+
+        ode_sample = self.cfg.sde.ode_sample
+        if self.cfg.sde.ode_sample == 1:
+            diffusion = self.diffusion_cont
+        elif self.cfg.sde.ode_sample == 0:
+            diffusion = self.diffusion_disc
+        elif self.cfg.sde.ode_sample == 2:
+            diffusion = [self.diffusion_cont, self.diffusion_disc]
+
+        # ---- forward sampling ---- #
+        gen_x, nstep, ode_time, sample_time, output_fsample = \
+            self.fun_generate_samples_vada(latent_shape, self.dae,
+                                           diffusion, self.model, num_shapes,
+                                           enable_autocast=self.cfg.sde.autocast_train,
+                                           ode_sample=ode_sample,
+                                           need_denoise=self.cfg.eval.need_denoise,
+                                           ddim_step=ddim_step,
+                                           clip_feat=clip_feat)
+        # gen_x: BNC
+        CHECKEQ(gen_x.shape[2], self.cfg.ddpm.input_dim)
+        if gen_x.shape[1] > self.sample_num_points:
+            gen_x = pvcnn_fn.furthest_point_sample(gen_x.permute(0, 2, 1).contiguous(),
+                                                   self.sample_num_points).permute(0, 2, 1).contiguous()  # [B,C,npoint]
+
+        traj = gen_x.permute(0, 2, 1).contiguous()  # BN3->B3N
+
+        # ---- for debugging purposes ---- #
+        if save_file:
+            if not os.path.exists(os.path.dirname(save_file)):
+                os.makedirs(os.path.dirname(save_file))
+            torch.save(traj.permute(0, 2, 1), save_file)
+            exit()
+
+        # switch back to original parameters
+        if cfg.ddpm.ema:
+            self.dae_optimizer.swap_parameters_with_ema(
+                store_params_in_ema=True)
+            self.swap_vae_param_if_need()
+        return traj
+
+    def build_prior(self):
+        args = self.cfg.sde
+        device = torch.device(self.device_str)
+        arch_instance_dae = utils.get_arch_cells_denoising(
+            'res_ho_attn', True, False)
+        num_input_channels = self.cfg.shapelatent.latent_dim
+
+        if self.cfg.sde.hier_prior:
+            if self.cfg.sde.prior_model == 'sim':
+                DAE = NCSNppPointHie
+            else:
+                DAE = import_model(self.cfg.sde.prior_model)
+        elif self.cfg.sde.prior_model == 'sim':
+            DAE = NCSNppPoint
+        else:
+            DAE = import_model(self.cfg.sde.prior_model)
+
+        self.dae = DAE(args, num_input_channels, self.cfg).to(device)
+        if len(self.cfg.sde.dae_checkpoint):
+            logger.info('Load dae checkpoint: {}',
+                        self.cfg.sde.dae_checkpoint)
+            checkpoint = torch.load(
+                self.cfg.sde.dae_checkpoint, map_location='cpu')
+            self.dae.load_state_dict(checkpoint['dae_state_dict'])
+
+        self.diffusion_cont = make_diffusion(args)
+        self.diffusion_disc = DiffusionDiscretized(
+            args,
self.diffusion_cont.var, self.cfg) + logger.info('DAE: {}', self.dae) + logger.info('DAE: param size = %fM ' % + utils.count_parameters_in_M(self.dae)) + + ## self.check_consistence(self.diffusion_cont, self.diffusion_disc) + # sync all parameters across all GPUs by broadcasting from rank 0. + utils.broadcast_params(self.dae.parameters(), self.args.distributed) + + def swap_vae_param_if_need(self): + if self.cfg.eval.load_other_vae_ckpt: + self.optimizer.swap_parameters_with_ema(store_params_in_ema=True) diff --git a/utils/checker.py b/utils/checker.py new file mode 100644 index 0000000..c89a778 --- /dev/null +++ b/utils/checker.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch + +def CHECKDIM(tensor, dim, val): + if type(tensor) == list: + for t in tensor: + CHECKDIM(t, dim, val) + else: + assert(len(tensor.shape) > dim), 'tensor with shape {} has no dim {}'.format(tensor.shape, dim) + if type(val) is list: + assert(tensor.shape[dim] in val), 'expect dim {} of {} to be in {}'.format( + dim, tensor.shape, val) + else: + assert(tensor.shape[dim] == val), 'expect tensor with shape: {} having dim {} as {}'.format( + tensor.shape, dim, val) + + return True + +def CHECK5D(tensor, *args): + assert(len(tensor.shape) == 5), 'get {} {}'.format(tensor.shape, len(tensor.shape)) + for t in args: + CHECK5D(t) + return tensor.shape + +def CHECK3D(tensor, *args): + assert(len(tensor.shape) == 3), 'get {} {}'.format(tensor.shape, len(tensor.shape)) + for t in args: + CHECK3D(t) + return tensor.shape + +def CHECK4D(tensor): + assert(len(tensor.shape) == 4), 'get {} {}'.format(tensor.shape, len(tensor.shape)) + return tensor.shape + +def CHECKND(tensor, N): + assert(len(tensor.shape) == N), 'get tensor shape:{} DIM={}, expect:{}'.format(tensor.shape, len(tensor.shape), N) + return tensor.shape + +def CHECK2D(tensor): + assert(len(tensor.shape) == 2), 'get {} {}'.format(tensor.shape, len(tensor.shape)) + return tensor.shape + +def CHECK_N3or6(input): + # expect input in shape (N,3) or (N,6) + CHECK_TENSOR(input) + CHECK2D(input) + assert(input.shape[1] == 3 or input.shape[1] == 6), f'expect shape N,3 or N,6; get {input.shape}' + return input.shape + +def CHECK_N3or6or9(input): + # expect input in shape (N,3), (N,6) or (N,9) + CHECK_TENSOR(input) + CHECK2D(input) + assert(input.shape[1] == 3 or input.shape[1] == 6 or input.shape[1] == 9), f'expect shape N,3, N,6 or N,9; get {input.shape}' + return input.shape + +def CHECK_N3(input): + # expect input in shape (N,3) + CHECK_TENSOR(input) + CHECK2D(input) + CHECKDIM(input, dim=1, val=3) + return input.shape + +def CHECK_TENSOR(input): + assert(torch.is_tensor(input)), f'expect tensor, get {type(input)}' + +def CHECKEQ(a, b): + assert(a == b), f'expect a=b, get a={a} and b={b}' + +def CHECKSIZE(t, values): + CHECKND(t, len(values)) + for iv, vv in enumerate(values): + CHECKDIM(t, iv, vv) + +def CHECKSAMESIZE(t1, t2): + CHECKSIZE(t1, t2.shape) diff --git a/utils/data_helper.py b/utils/data_helper.py new file mode 100644 index 0000000..f32fa78 --- /dev/null +++ b/utils/data_helper.py @@ -0,0 +1,35 @@ +# Copyright
(c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +def normalize_point_clouds(pcs, mode='shape_bbox'): + ''' + copied from https://github.com/luost26/diffusion-point-cloud/blob/0bfd688379e78ac75fa75e6a2c5029e362496169/test_gen.py#L16 + Args: + pcs: list of [N,3] tensors, or a tensor of shape B,N,3 + ''' + # logger.debug('Normalization mode: %s' % mode) + assert(type(pcs) == list or len(pcs.shape) == + 3), f'expect pcs to be a list or a 3D tensor; get: {type(pcs)}' + output_list = [] + for i in range(len(pcs)): + pc = pcs[i] + pc = pc.detach().clone() + assert(mode == 'shape_bbox') + assert(len(pc.shape) == 2 and pc.shape[-1] in [3, 4, + 6, 9]), f'expect (N, C) with C in (3, 4, 6, 9); get {pc.shape}' + pc_max, _ = pc.max(dim=0, keepdim=True) # (1, C) + pc_min, _ = pc.min(dim=0, keepdim=True) # (1, C) + pc_min = pc_min[:, :3] + pc_max = pc_max[:, :3] + shift = ((pc_min + pc_max) / 2).view(1, 3) + scale = (pc_max - pc_min).max().reshape(1, 1) / 2 + pc[:, :3] = (pc[:, :3] - shift) / scale + output_list.append(pc) + return output_list +
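+ +# A minimal usage sketch for normalize_point_clouds (illustrative only; the +# shapes and values below are made up): +# +# import torch +# pcs = [torch.rand(2048, 3) * 10.0, torch.rand(2048, 3) * 0.1] +# out = normalize_point_clouds(pcs, mode='shape_bbox') +# # each output cloud is centered at its bbox center and scaled so the +# # longest bbox side spans [-1, 1]: +# assert all(o.abs().max() <= 1.0 + 1e-6 for o in out) diff --git a/utils/diffusion.py b/utils/diffusion.py new file mode 100644 index 0000000..13dce95 --- /dev/null +++ b/utils/diffusion.py @@ -0,0 +1,170 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.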
+""" +copied and modified from source: + https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_discretized.py +""" +from loguru import logger +import time +import torch +import torch.nn.functional as F +from torch.nn import Module, Parameter, ModuleList +import numpy as np + + +def extract(input, t, shape): + B = t.shape[0] + out = torch.gather(input, 0, t.to(input.device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out + + +def make_beta_schedule(schedule, start, end, n_timestep): + if schedule == "cust": # airplane + b_start = start + b_end = end + time_num = n_timestep + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace( + b_start, b_end, warmup_time, dtype=np.float64) + betas = torch.from_numpy(betas) + + #betas = torch.zeros(n_timestep, dtype=torch.float64) + end + #n_timestep_90 = int(n_timestep*0.9) + # betas_0 = torch.linspace(start, + # end, + # n_timestep_90, + # dtype=torch.float64) + #betas[:n_timestep_90] = betas_0 + + elif schedule == "quad": + betas = torch.linspace(start**0.5, + end**0.5, + n_timestep, + dtype=torch.float64)**2 + elif schedule == 'linear': + betas = torch.linspace(start, end, n_timestep, dtype=torch.float64) + elif schedule == 'warmup10': + betas = _warmup_beta(start, end, n_timestep, 0.1) + elif schedule == 'warmup50': + betas = _warmup_beta(start, end, n_timestep, 0.5) + elif schedule == 'const': + betas = end * torch.ones(n_timestep, dtype=torch.float64) + elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1. / (torch.linspace( + n_timestep, 1, n_timestep, dtype=torch.float64)) + else: + raise NotImplementedError(schedule) + return betas + + +class VarianceSchedule(Module): + def __init__(self, num_steps, beta_1, beta_T, mode='linear'): + super().__init__() + assert mode in ('linear', 'cust') + self.num_steps = num_steps + self.beta_1 = beta_1 + self.beta_T = beta_T + self.mode = mode + beta_start = self.beta_1 + beta_end = self.beta_T + assert (beta_start <= beta_end), 'require beta_start < beta_end ' + + logger.info('use beta: {} - {}', beta_1, beta_T) + tic = time.time() + # betas = torch.linspace(beta_1, beta_T, steps=num_steps) + betas = make_beta_schedule(mode, beta_start, beta_end, num_steps) + # elif mode == 'customer': + # beta_0 = 10−5 and beta_T = 0.008 for 90% step, beta_T=0.0088 + ## num_steps_90 = int(0.9 * num_steps) + # logger.info('use beta_0=1e-5 and beta_T=0.008 ' + ## 'for {} step and 0.008 for the rest', + # num_steps_90) + ## betas_sub = torch.linspace(1e-5, 0.008, steps=num_steps_90) + ## betas_full = torch.zeros(num_steps) + 0.008 + ## betas_full[:num_steps_90] = betas_sub + ## betas = betas_full + + # betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding + + #alphas = 1 - betas + #log_alphas = torch.log(alphas) + # for i in range(1, log_alphas.size(0)): # 1 to T + # log_alphas[i] += log_alphas[i - 1] + #alpha_bars = log_alphas.exp() + + #sigmas_flex = torch.sqrt(betas) + #sigmas_inflex = torch.zeros_like(sigmas_flex) + # for i in range(1, sigmas_flex.size(0)): + # sigmas_inflex[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] + #sigmas_inflex = torch.sqrt(sigmas_inflex) + #sqrt_recip_alphas_cumprod = torch.rsqrt(alpha_bars) + #sqrt_recipm1_alphas_cumprod = torch.sqrt(1 / alpha_bars - 1) + + #self.register_buffer('betas', betas) + #self.register_buffer('alphas', alphas) + #self.register_buffer('alpha_bars', alpha_bars) + 
#self.register_buffer('sigmas_flex', sigmas_flex) + #self.register_buffer('sigmas_inflex', sigmas_inflex) + #self.register_buffer('sqrt_recip_alphas_cumprod', sqrt_recip_alphas_cumprod) + # self.register_buffer('sqrt_recipm1_alphas_cumprod', + # sqrt_recipm1_alphas_cumprod) + alphas = 1 - betas + alphas_cumprod = torch.cumprod(alphas, 0) + alphas_cumprod_prev = torch.cat( + (torch.tensor([1], dtype=torch.float64), alphas_cumprod[:-1]), 0) + posterior_variance = betas * (1 - alphas_cumprod_prev) / ( + 1 - alphas_cumprod) + self.register("betas", betas) + self.register("alphas_cumprod", alphas_cumprod) + self.register("alphas_cumprod_prev", alphas_cumprod_prev) + + self.register("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) + self.register("sqrt_one_minus_alphas_cumprod", + torch.sqrt(1 - alphas_cumprod)) + self.register("log_one_minus_alphas_cumprod", + torch.log(1 - alphas_cumprod)) + self.register("sqrt_recip_alphas_cumprod", torch.rsqrt(alphas_cumprod)) + self.register("sqrt_recipm1_alphas_cumprod", + torch.sqrt(1 / alphas_cumprod - 1)) + self.register("posterior_variance", posterior_variance) + if len(posterior_variance) > 1: + self.register("posterior_log_variance_clipped", + torch.log( + torch.cat((posterior_variance[1].view( + 1, 1), posterior_variance[1:].view(-1, 1)), 0)).view(-1) + ) + else: + self.register("posterior_log_variance_clipped", + torch.log(posterior_variance[0].view(-1))) + self.register("posterior_mean_coef1", + (betas * torch.sqrt(alphas_cumprod_prev) / + (1 - alphas_cumprod))) + self.register("posterior_mean_coef2", + ((1 - alphas_cumprod_prev) * torch.sqrt(alphas) / + (1 - alphas_cumprod))) + logger.info('built beta schedule: t={:.2f}s', time.time() - tic) + + def register(self, name, tensor): + self.register_buffer(name, tensor.type(torch.float32)) + + def all_sample_t(self): + if self.num_steps > 20: + step = 50 + else: + step = 1 + ts = np.arange(0, self.num_steps, step) + return ts.tolist() + + def get_sigmas(self, t, flexibility): + assert 0 <= flexibility and flexibility <= 1 + sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * ( + 1 - flexibility) + return sigmas diff --git a/utils/diffusion_continuous.py b/utils/diffusion_continuous.py new file mode 100644 index 0000000..f6e8062 --- /dev/null +++ b/utils/diffusion_continuous.py @@ -0,0 +1,845 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +"""modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_continuous.py""" +from abc import ABC, abstractmethod +import numpy as np +import torch +import gc +# import utils.distributions as distributions +from utils.utils import trace_df_dx_hutchinson, sample_gaussian_like, sample_rademacher_like, get_mixed_prediction +from third_party.torchdiffeq.torchdiffeq import odeint +from torch.cuda.amp import autocast +from timeit import default_timer as timer +from loguru import logger + + +def make_diffusion(args): + """ simple diffusion factory function to return diffusion instances. 
Only use this to create continuous diffusions """ + if args.sde_type == 'geometric_sde': + return DiffusionGeometric(args) + elif args.sde_type == 'vpsde': + return DiffusionVPSDE(args) + elif args.sde_type == 'sub_vpsde': + return DiffusionSubVPSDE(args) + elif args.sde_type == 'power_vpsde': + return DiffusionPowerVPSDE(args) + elif args.sde_type == 'sub_power_vpsde': + return DiffusionSubPowerVPSDE(args) + elif args.sde_type == 'vesde': + return DiffusionVESDE(args) + else: + raise ValueError("Unrecognized sde type: {}".format(args.sde_type)) + + +class DiffusionBase(ABC): + """ + Abstract base class for all diffusion implementations. + """ + + def __init__(self, args): + super().__init__() + self.sigma2_0 = args.sigma2_0 + self.sde_type = args.sde_type + + @abstractmethod + def f(self, t): + """ returns the drift coefficient at time t: f(t) """ + pass + + @abstractmethod + def g2(self, t): + """ returns the squared diffusion coefficient at time t: g^2(t) """ + pass + + @abstractmethod + def var(self, t): + """ returns variance at time t, \sigma_t^2 + q(zt|z0) = N(zt; \mu_t(z0), \sigma_t^2 I) + """ + pass + + @abstractmethod + def e2int_f(self, t): + """ returns e^{\int_0^t f(s) ds} which corresponds to the coefficient of mean at time t. """ + pass + + @abstractmethod + def inv_var(self, var): + """ inverse of the variance function at input variance var. """ + pass + + @abstractmethod + def mixing_component(self, x_noisy, var_t, t, enabled): + """ returns mixing component which is the optimal denoising model assuming that q(z_0) is N(0, 1) """ + pass + + def sample_q(self, x_init, noise, var_t, m_t): + """ returns a sample from diffusion process at time t """ + return m_t * x_init + torch.sqrt(var_t) * noise + + def cross_entropy_const(self, ode_eps): + """ returns cross entropy factor with variance according to ode integration cutoff ode_eps """ + # _, c, h, w = x_init.shape + return 0.5 * (1.0 + torch.log(2.0 * np.pi * self.var(t=torch.tensor(ode_eps, device='cuda')))) + + def compute_ode_nll(self, dae, eps, ode_eps, ode_solver_tol, enable_autocast=False, + no_autograd=False, num_samples=1, report_std=False, + condition_input=None, clip_feat=None): + ## raise NotImplementedError + """ calculates NLL based on ODE framework, assuming integration cutoff ode_eps """ + # ODE solver starts consuming the CPU memory without this on large models + # https://github.com/scipy/scipy/issues/10070 + gc.collect() + + dae.eval() + + def ode_func(t, x): + """ the ode function (including log probability integration for NLL calculation) """ + global nfe_counter + nfe_counter = nfe_counter + 1 + + # x = state[0].detach() + x = x.detach() + x.requires_grad_(False) + # noise = sample_gaussian_like(x) # could also use rademacher noise (sample_rademacher_like) + with torch.set_grad_enabled(False): + with autocast(enabled=enable_autocast): + variance = self.var(t=t) + mixing_component = self.mixing_component( + x_noisy=x, var_t=variance, t=t, enabled=dae.mixed_prediction) + pred_params = dae( + x=x, t=t, condition_input=condition_input, clip_feat=clip_feat) + # Warning: here mixing_logit can be NOne + params = get_mixed_prediction( + dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component) + dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \ + params / torch.sqrt(variance) + # dx_dt = - 0.5 * self.g2(t=t) * (x - params / torch.sqrt(variance)) + + # with autocast(enabled=False): + # dlogp_x_dt = -trace_df_dx_hutchinson(dx_dt, x, noise, no_autograd).view(x.shape[0], 1) + + return dx_dt + + # NFE 
counter + global nfe_counter + + nll_all, nfe_all = [], [] + for i in range(num_samples): + # integrated log probability + # logp_diff_t0 = torch.zeros(eps.shape[0], 1, device='cuda') + + nfe_counter = 0 + + # solve the ODE + x_t = odeint( + ode_func, + eps, + torch.tensor([ode_eps, 1.0], device='cuda'), + atol=ode_solver_tol, # 1e-5 + rtol=ode_solver_tol, # 1e-5 + # 'dopri5' or 'dopri8' methods also seems good. + method="scipy_solver", + options={"solver": 'RK45'}, # only for scipy solvers + ) + # last output values + x_t0 = x_t[-1] + ## x_t0, logp_diff_t0 = x_t[-1], logp_diff_t[-1] + + # prior + # if self.sde_type == 'vesde': + # logp_prior = torch.sum(distributions.log_p_var_normal(x_t0, var=self.sigma2_max), dim=[1, 2, 3]) + # else: + # logp_prior = torch.sum(distributions.log_p_standard_normal(x_t0), dim=[1, 2, 3]) + + #log_likelihood = logp_prior - logp_diff_t0.view(-1) + + # nll_all.append(-log_likelihood) + nfe_all.append(nfe_counter) + print('nfe_counter: ', nfe_counter) + + #nfe_mean = np.mean(nfe_all) + ##nll_all = torch.stack(nll_all, dim=1) + #nll_mean = torch.mean(nll_all, dim=1) + # if num_samples > 1 and report_std: + # nll_stddev = torch.std(nll_all ,dii=1) + # nll_stddev_batch = torch.mean(nll_stddev) + # nll_stderror_batch = nll_stddev_batch / np.sqrt(num_samples) + # else: + # nll_stddev_batch = None + # nll_stderror_batch = None + return x_t0 # nll_mean, nfe_mean, nll_stddev_batch, nll_stderror_batch + + def sample_model_ode(self, dae, num_samples, shape, ode_eps, + ode_solver_tol, enable_autocast, temp, noise=None, + condition_input=None, mixing_logit=None, + use_cust_ode_func=0, init_t=1.0, return_all_sample=False, clip_feat=None + ): + """ generates samples using the ODE framework, assuming integration cutoff ode_eps """ + # ODE solver starts consuming the CPU memory without this on large models + # https://github.com/scipy/scipy/issues/10070 + gc.collect() + + dae.eval() + + def cust_ode_func(t, x): + """ the ode function (sampling only, no NLL stuff) """ + global nfe_counter + nfe_counter = nfe_counter + 1 + if nfe_counter % 100 == 0: + logger.info('nfe_counter={}', nfe_counter) + with autocast(enabled=enable_autocast): + variance = self.var(t=t) + params = dae(x, x, t, condition_input=condition_input) + dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \ + params / torch.sqrt(variance) + # dx_dt = - 0.5 * self.g2(t=t) * (x - params / torch.sqrt(variance)) + + return dx_dt + + def ode_func(t, x): + """ the ode function (sampling only, no NLL stuff) """ + global nfe_counter + nfe_counter = nfe_counter + 1 + if nfe_counter % 100 == 0: + logger.info('nfe_counter={}', nfe_counter) + with autocast(enabled=enable_autocast): + variance = self.var(t=t) + mixing_component = self.mixing_component( + x_noisy=x, var_t=variance, t=t, enabled=dae.mixed_prediction) + pred_params = dae( + x=x, t=t, condition_input=condition_input, clip_feat=clip_feat) + input_mixing_logit = mixing_logit if mixing_logit is not None else dae.mixing_logit + params = get_mixed_prediction( + dae.mixed_prediction, pred_params, input_mixing_logit, mixing_component) + dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \ + params / torch.sqrt(variance) + # dx_dt = - 0.5 * self.g2(t=t) * (x - params / torch.sqrt(variance)) + + return dx_dt + + # the initial noise + if noise is None: + noise = torch.randn(size=[num_samples] + shape, device='cuda') + + if self.sde_type == 'vesde': + noise_init = temp * noise * np.sqrt(self.sigma2_max) + else: + noise_init = temp * noise + + # NFE counter + global nfe_counter + 
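+ # The ODE integrated below is the probability-flow ODE of the forward SDE, + # dx/dt = f(t) * x + 0.5 * g2(t) * eps_theta(x, t) / sqrt(var(t)), + # where eps_theta / sqrt(var) approximates the negative score -grad log p_t. + # Integrating from t=init_t down to t=ode_eps deterministically maps the + # Gaussian noise_init to a sample; nfe_counter counts the network + # evaluations spent by the adaptive solver.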
nfe_counter = 0 + + # solve the ODE + start = timer() + samples_out = odeint( + ode_func if not use_cust_ode_func else cust_ode_func, + noise_init, + torch.tensor([init_t, ode_eps], device='cuda'), + atol=ode_solver_tol, # 1e-5 + rtol=ode_solver_tol, # 1e-5 + # 'dopri5' or 'dopri8' methods also seems good. + method="scipy_solver", + options={"solver": 'RK45'}, # only for scipy solvers + ) + end = timer() + ode_solve_time = end - start + if return_all_sample: + return samples_out[-1], samples_out, nfe_counter, ode_solve_time + return samples_out[-1], nfe_counter, ode_solve_time + + # def compute_dsm_nll(self, dae, eps, time_eps, enable_autocast, num_samples, report_std): + # with torch.no_grad(): + # neg_log_prob_all = [] + # for i in range(num_samples): + # assert self.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde'], "we don't support subVPSDE yet." + # t, var_t, m_t, obj_weight_t, _, _ = \ + # self.iw_quantities(eps.shape[0], time_eps, iw_sample_mode='ll_iw', iw_subvp_like_vp_sde=False) + + # noise = torch.randn(size=eps.size(), device='cuda') + # eps_t = self.sample_q(eps, noise, var_t, m_t) + # mixing_component = self.mixing_component(eps_t, var_t, t, enabled=dae.mixed_prediction) + # with autocast(enabled=enable_autocast): + # pred_params = dae(eps_t, t) + # params = get_mixed_prediction(dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component) + # l2_term = torch.square(params - noise) + + # neg_log_prob_per_var = obj_weight_t * l2_term + # neg_log_prob_per_var += self.cross_entropy_const(time_eps) + # neg_log_prob = torch.sum(neg_log_prob_per_var, dim=[1, 2, 3]) + + # neg_log_prob_all.append(neg_log_prob) + + # neg_log_prob_all = torch.stack(neg_log_prob_all, dim=1) + # nll = torch.mean(neg_log_prob_all, dim=1) + # if num_samples > 1 and report_std: + # nll_std = torch.std(neg_log_prob_all, dim=1) + # print('std nll:', nll_std) + + # return nll + + def iw_quantities(self, size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde): + if self.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde']: + return self._iw_quantities_vpsdelike(size, time_eps, iw_sample_mode) + elif self.sde_type in ['sub_vpsde', 'sub_power_vpsde']: + return self._iw_quantities_subvpsdelike(size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde) + elif self.sde_type in ['vesde']: + return self._iw_quantities_vesde(size, time_eps, iw_sample_mode) + else: + raise NotImplementedError + + def debug_sheduler(self, time_eps): + # time_eps, 1-time_eps, 1000) ##-1) / 1000.0 + time_eps + t = torch.linspace(0, 1, 1000) + t = torch.range(1, 1000) / 1000.0 + t = t.cuda() + var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + obj_weight_t_p = torch.ones(1, device='cuda') + obj_weight_t_q = g2_t / (2.0 * var_t) + return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), \ + obj_weight_t_p.view(-1, 1, 1, 1), \ + obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1) + + def _iw_quantities_vpsdelike(self, size, time_eps, iw_sample_mode): + """ + For all SDEs where the underlying SDE is of the form dz = -0.5 * beta(t) * z * dt + sqrt{beta(t)} * dw, like + for the VPSDE. + """ + rho = torch.rand(size=[size], device='cuda') + + if iw_sample_mode == 'll_uniform': + # uniform t sampling - likelihood obj. for both q and p + t = rho * (1. - time_eps) + time_eps + var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t) + + elif iw_sample_mode == 'll_iw': + # importance sampling for likelihood obj. - likelihood obj. 
for both q and p + ones = torch.ones_like(rho, device='cuda') + sigma2_1, sigma2_eps = self.var(ones), self.var(time_eps * ones) + log_sigma2_1, log_sigma2_eps = torch.log( + sigma2_1), torch.log(sigma2_eps) + var_t = torch.exp(rho * log_sigma2_1 + (1 - rho) * log_sigma2_eps) + t = self.inv_var(var_t) + m_t, g2_t = self.e2int_f(t), self.g2(t) + obj_weight_t_p = obj_weight_t_q = 0.5 * \ + (log_sigma2_1 - log_sigma2_eps) / (1.0 - var_t) + + elif iw_sample_mode == 'drop_all_uniform': + # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p + t = rho * (1. - time_eps) + time_eps + var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + obj_weight_t_p = torch.ones(1, device='cuda') + obj_weight_t_q = g2_t / (2.0 * var_t) + + elif iw_sample_mode == 'drop_all_iw': + # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p + assert self.sde_type == 'vpsde', 'Importance sampling for fully unweighted objective is currently only ' \ + 'implemented for the regular VPSDE.' + t = torch.sqrt(1.0 / self.delta_beta_half) * torch.erfinv(rho * + self.const_norm_2 + self.const_erf) - self.beta_frac + var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + obj_weight_t_p = self.const_norm / (1.0 - var_t) + obj_weight_t_q = obj_weight_t_p * g2_t / (2.0 * var_t) + + elif iw_sample_mode == 'drop_sigma2t_iw': + # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p + ones = torch.ones_like(rho, device='cuda') + sigma2_1, sigma2_eps = self.var(ones), self.var(time_eps * ones) + var_t = rho * sigma2_1 + (1 - rho) * sigma2_eps + t = self.inv_var(var_t) + m_t, g2_t = self.e2int_f(t), self.g2(t) + obj_weight_t_p = 0.5 * (sigma2_1 - sigma2_eps) / (1.0 - var_t) + obj_weight_t_q = obj_weight_t_p / var_t + + elif iw_sample_mode == 'drop_sigma2t_uniform': + # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p + t = rho * (1. - time_eps) + time_eps + var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + obj_weight_t_p = g2_t / 2.0 + obj_weight_t_q = g2_t / (2.0 * var_t) + + elif iw_sample_mode == 'rescale_iw': + # importance sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p + t = rho * (1. - time_eps) + time_eps + var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + obj_weight_t_p = 0.5 / (1.0 - var_t) + obj_weight_t_q = g2_t / (2.0 * var_t) + + else: + raise ValueError( + "Unrecognized importance sampling type: {}".format(iw_sample_mode)) + + return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \ + obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1) + + # def _iw_quantities_subvpsdelike(self, size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde): + # """ + # For all SDEs where the underlying SDE is of the form + # dz = -0.5 * beta(t) * z * dt + sqrt{beta(t) * (1 - exp[-2 * betaintegral])} * dw, like for the Sub-VPSDE. + # When iw_subvp_like_vp_sde is True, then we define the importance sampling distributions based on an analogous + # VPSDE, while stile using the Sub-VPSDE. The motivation is that deriving the correct importance sampling + # distributions for the Sub-VPSDE itself is hard, but the importance sampling distributions from analogous VPSDEs + # probably already significantly reduce the variance also for the Sub-VPSDE. 
+ # """ + # rho = torch.rand(size=[size], device='cuda') + + # if iw_sample_mode == 'll_uniform': + # # uniform t sampling - likelihood obj. for both q and p + # t = rho * (1. - time_eps) + time_eps + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t) + + # elif iw_sample_mode == 'll_iw': + # if iw_subvp_like_vp_sde: + # # importance sampling for vpsde likelihood obj. - sub-vpsde likelihood obj. for both q and p + # ones = torch.ones_like(rho, device='cuda') + # sigma2_1, sigma2_eps = self.var_vpsde(ones), self.var_vpsde(time_eps * ones) + # log_sigma2_1, log_sigma2_eps = torch.log(sigma2_1), torch.log(sigma2_eps) + # var_t_vpsde = torch.exp(rho * log_sigma2_1 + (1 - rho) * log_sigma2_eps) + # t = self.inv_var_vpsde(var_t_vpsde) + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t) * \ + # (log_sigma2_1 - log_sigma2_eps) * var_t_vpsde / (1 - var_t_vpsde) / self.beta(t) + # else: + # raise NotImplementedError + + # elif iw_sample_mode == 'drop_all_uniform': + # # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p + # t = rho * (1. - time_eps) + time_eps + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = torch.ones(1, device='cuda') + # obj_weight_t_q = g2_t / (2.0 * var_t) + + # elif iw_sample_mode == 'drop_all_iw': + # if iw_subvp_like_vp_sde: + # # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p + # assert self.sde_type == 'sub_vpsde', 'Importance sampling for fully unweighted objective is ' \ + # 'currently only implemented for the Sub-VPSDE.' + # t = torch.sqrt(1.0 / self.delta_beta_half) * torch.erfinv(rho * self.const_norm_2 + self.const_erf) - self.beta_frac + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = self.const_norm / (1.0 - self.var_vpsde(t)) + # obj_weight_t_q = obj_weight_t_p * g2_t / (2.0 * var_t) + # else: + # raise NotImplementedError + + # elif iw_sample_mode == 'drop_sigma2t_iw': + # if iw_subvp_like_vp_sde: + # # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p + # ones = torch.ones_like(rho, device='cuda') + # sigma2_1, sigma2_eps = self.var_vpsde(ones), self.var_vpsde(time_eps * ones) + # var_t_vpsde = rho * sigma2_1 + (1 - rho) * sigma2_eps + # t = self.inv_var_vpsde(var_t_vpsde) + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = 0.5 * g2_t / self.beta(t) * (sigma2_1 - sigma2_eps) / (1.0 - var_t_vpsde) + # obj_weight_t_q = obj_weight_t_p / var_t + # else: + # raise NotImplementedError + + # elif iw_sample_mode == 'drop_sigma2t_uniform': + # # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p + # t = rho * (1. - time_eps) + time_eps + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = g2_t / 2.0 + # obj_weight_t_q = g2_t / (2.0 * var_t) + + # elif iw_sample_mode == 'rescale_iw': + # # importance sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p + # # Note that we use the sub-vpsde variance to scale the p objective! It's not clear what's optimal here! + # t = rho * (1. 
- time_eps) + time_eps + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = 0.5 / (1.0 - var_t) + # obj_weight_t_q = g2_t / (2.0 * var_t) + + # else: + # raise ValueError("Unrecognized importance sampling type: {}".format(iw_sample_mode)) + + # return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \ + # obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1) + + # def _iw_quantities_vesde(self, size, time_eps, iw_sample_mode): + # """ + # For the VESDE. + # """ + # rho = torch.rand(size=[size], device='cuda') + + # if iw_sample_mode == 'll_uniform': + # # uniform t sampling - likelihood obj. for both q and p + # t = rho * (1. - time_eps) + time_eps + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t) + + # elif iw_sample_mode == 'll_iw': + # # importance sampling for likelihood obj. - likelihood obj. for both q and p + # ones = torch.ones_like(rho, device='cuda') + # nsigma2_1, nsigma2_eps, sigma2_eps = self.var_N(ones), self.var_N(time_eps * ones), self.var(time_eps * ones) + # log_frac_sigma2_1, log_frac_sigma2_eps = torch.log(self.sigma2_max / nsigma2_1), torch.log(nsigma2_eps / sigma2_eps) + # var_N_t = (1.0 - self.sigma2_min) / (1.0 - torch.exp(rho * (log_frac_sigma2_1 + log_frac_sigma2_eps) - log_frac_sigma2_eps)) + # t = self.inv_var_N(var_N_t) + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = obj_weight_t_q = 0.5 * (log_frac_sigma2_1 + log_frac_sigma2_eps) * self.var_N(t) / (1.0 - self.sigma2_min) + + # elif iw_sample_mode == 'drop_all_uniform': + # # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p + # t = rho * (1. - time_eps) + time_eps + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = torch.ones(1, device='cuda') + # obj_weight_t_q = g2_t / (2.0 * var_t) + + # elif iw_sample_mode == 'drop_all_iw': + # # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p + # ones = torch.ones_like(rho, device='cuda') + # nsigma2_1, nsigma2_eps, sigma2_eps = self.var_N(ones), self.var_N(time_eps * ones), self.var(time_eps * ones) + # log_frac_sigma2_1, log_frac_sigma2_eps = torch.log(self.sigma2_max / nsigma2_1), torch.log(nsigma2_eps / sigma2_eps) + # var_N_t = (1.0 - self.sigma2_min) / (1.0 - torch.exp(rho * (log_frac_sigma2_1 + log_frac_sigma2_eps) - log_frac_sigma2_eps)) + # t = self.inv_var_N(var_N_t) + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_q = 0.5 * (log_frac_sigma2_1 + log_frac_sigma2_eps) * self.var_N(t) / (1.0 - self.sigma2_min) + # obj_weight_t_p = 2.0 * obj_weight_t_q / np.log(self.sigma2_max / self.sigma2_min) + + # elif iw_sample_mode == 'drop_sigma2t_iw': + # # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p + # ones = torch.ones_like(rho, device='cuda') + # nsigma2_1, nsigma2_eps = self.var_N(ones), self.var_N(time_eps * ones) + # var_N_t = torch.exp(rho * torch.log(nsigma2_1) + (1 - rho) * torch.log(nsigma2_eps)) + # t = self.inv_var_N(var_N_t) + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = 0.5 * torch.log(nsigma2_1 / nsigma2_eps) * self.var_N(t) + # obj_weight_t_q = obj_weight_t_p / var_t + + # elif iw_sample_mode == 'drop_sigma2t_uniform': + # # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. 
for p + # t = rho * (1. - time_eps) + time_eps + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = g2_t / 2.0 + # obj_weight_t_q = g2_t / (2.0 * var_t) + + # elif iw_sample_mode == 'rescale_iw': + # # uniform sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p + # t = rho * (1. - time_eps) + time_eps + # var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t) + # obj_weight_t_p = 0.5 / (1.0 - var_t) + # obj_weight_t_q = g2_t / (2.0 * var_t) + + # else: + # raise ValueError("Unrecognized importance sampling type: {}".format(iw_sample_mode)) + + # return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \ + # obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1) + + +# class DiffusionGeometric(DiffusionBase): +# """ +# Diffusion implementation with dz = -0.5 * beta(t) * z * dt + sqrt(beta(t)) * dW SDE and geometric progression of +# variance. This is our new diffusion. +# """ +# def __init__(self, args): +# super().__init__(args) +# self.sigma2_min = args.sigma2_min +# self.sigma2_max = args.sigma2_max +# +# def f(self, t): +# return -0.5 * self.g2(t) +# +# def g2(self, t): +# sigma2_geom = self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t) +# log_term = np.log(self.sigma2_max / self.sigma2_min) +# return sigma2_geom * log_term / (1.0 - self.sigma2_0 + self.sigma2_min - sigma2_geom) +# +# def var(self, t): +# return self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t) - self.sigma2_min + self.sigma2_0 +# +# def e2int_f(self, t): +# return torch.sqrt(1.0 + self.sigma2_min * (1.0 - (self.sigma2_max / self.sigma2_min) ** t) / (1.0 - self.sigma2_0)) +# +# def inv_var(self, var): +# return torch.log((var + self.sigma2_min - self.sigma2_0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min) +# +# def mixing_component(self, x_noisy, var_t, t, enabled): +# if enabled: +# return torch.sqrt(var_t) * x_noisy +# else: +# return None +# + +class DiffusionVPSDE(DiffusionBase): + """ + Diffusion implementation of the VPSDE. This uses the same SDE like DiffusionGeometric but with linear beta(t). + Note that we need to scale beta_start and beta_end by 1000 relative to JH's DDPM values, since our t is in [0,1]. + """ + + def __init__(self, args): + super().__init__(args) + self.beta_start = args.beta_start + self.beta_end = args.beta_end + logger.info('VPSDE: beta_start={}, beta_end={}, sigma2_0={}', + self.beta_start, self.beta_end, self.sigma2_0) + # auxiliary constants (yes, this is not super clean...) 
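+ # The erf-based constants below normalize the 'drop_all_iw' importance- + # sampling proposal q(t) proportional to 1 - var(t): for the linear VPSDE + # this is a truncated Gaussian in t, so its CDF is an error function and + # samples are drawn by inverse-CDF via erfinv in _iw_quantities_vpsdelike().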
+ self.time_eps = args.time_eps + self.delta_beta_half = torch.tensor( + 0.5 * (self.beta_end - self.beta_start), device='cuda') + self.beta_frac = torch.tensor( + self.beta_start / (self.beta_end - self.beta_start), device='cuda') + self.const_aq = (1.0 - self.sigma2_0) * torch.exp(0.5 * + self.beta_frac) * torch.sqrt(0.25 * np.pi / self.delta_beta_half) + self.const_erf = torch.erf(torch.sqrt( + self.delta_beta_half) * (self.time_eps + self.beta_frac)) + self.const_norm = self.const_aq * \ + (torch.erf(torch.sqrt(self.delta_beta_half) + * (1.0 + self.beta_frac)) - self.const_erf) + self.const_norm_2 = torch.erf(torch.sqrt( + self.delta_beta_half) * (1.0 + self.beta_frac)) - self.const_erf + + def f(self, t): + return -0.5 * self.g2(t) + + def g2(self, t): + return self.beta_start + (self.beta_end - self.beta_start) * t + + def var(self, t): + return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t) + + def e2int_f(self, t): + return torch.exp(-0.5 * self.beta_start * t - 0.25 * (self.beta_end - self.beta_start) * t * t) + + def inv_var(self, var): + c = torch.log((1 - var) / (1 - self.sigma2_0)) + a = self.beta_end - self.beta_start + t = (-self.beta_start + torch.sqrt(np.square(self.beta_start) - 2 * a * c)) / a + return t + + def mixing_component(self, x_noisy, var_t, t, enabled): + if enabled: + return torch.sqrt(var_t) * x_noisy + else: + return None + + +# class DiffusionSubVPSDE(DiffusionBase): +# """ +# Diffusion implementation of the sub-VPSDE. Note that this uses a different SDE compared to the above two diffusions. +# """ +# def __init__(self, args): +# super().__init__(args) +# self.beta_start = args.beta_start +# self.beta_end = args.beta_end +# +# # auxiliary constants (assumes regular VPSDE... yes, this is not super clean...) 
+# self.time_eps = args.time_eps +# self.delta_beta_half = torch.tensor(0.5 * (self.beta_end - self.beta_start), device='cuda') +# self.beta_frac = torch.tensor(self.beta_start / (self.beta_end - self.beta_start), device='cuda') +# self.const_aq = (1.0 - self.sigma2_0) * torch.exp(0.5 * self.beta_frac) * torch.sqrt(0.25 * np.pi / self.delta_beta_half) +# self.const_erf = torch.erf(torch.sqrt(self.delta_beta_half) * (self.time_eps + self.beta_frac)) +# self.const_norm = self.const_aq * (torch.erf(torch.sqrt(self.delta_beta_half) * (1.0 + self.beta_frac)) - self.const_erf) +# self.const_norm_2 = torch.erf(torch.sqrt(self.delta_beta_half) * (1.0 + self.beta_frac)) - self.const_erf +# +# def f(self, t): +# return -0.5 * self.beta(t) +# +# def g2(self, t): +# return self.beta(t) * (1.0 - torch.exp(-2.0 * self.beta_start * t - (self.beta_end - self.beta_start) * t * t)) +# +# def var(self, t): +# int_term = torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t) +# return torch.square(1.0 - int_term) + self.sigma2_0 * int_term +# +# def e2int_f(self, t): +# return torch.exp(-0.5 * self.beta_start * t - 0.25 * (self.beta_end - self.beta_start) * t * t) +# +# def beta(self, t): +# """ auxiliary beta function """ +# return self.beta_start + (self.beta_end - self.beta_start) * t +# +# def inv_var(self, var): +# raise NotImplementedError +# +# def mixing_component(self, x_noisy, var_t, t, enabled): +# if enabled: +# int_term = torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t).view(-1, 1, 1, 1) +# return torch.sqrt(var_t) * x_noisy / (torch.square(1.0 - int_term) + int_term) +# else: +# return None +# +# def var_vpsde(self, t): +# return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t) +# +# def inv_var_vpsde(self, var): +# c = torch.log((1 - var) / (1 - self.sigma2_0)) +# a = self.beta_end - self.beta_start +# t = (-self.beta_start + torch.sqrt(np.square(self.beta_start) - 2 * a * c)) / a +# return t +# +# +# class DiffusionPowerVPSDE(DiffusionBase): +# """ +# Diffusion implementation of the power-VPSDE. This uses the same SDE like DiffusionGeometric but with beta function +# that is a power function with user specified power. Note that for power=1, this reproduces the vanilla +# DiffusionVPSDE above. 
+# """ +# def __init__(self, args): +# super().__init__(args) +# self.beta_start = args.beta_start +# self.beta_end = args.beta_end +# self.power = args.vpsde_power +# +# def f(self, t): +# return -0.5 * self.g2(t) +# +# def g2(self, t): +# return self.beta_start + (self.beta_end - self.beta_start) * t ** self.power +# +# def var(self, t): +# return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)) +# +# def e2int_f(self, t): +# return torch.exp(-0.5 * self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)) +# +# def inv_var(self, var): +# if self.power == 2: +# c = torch.log((1 - var) / (1 - self.sigma2_0)) +# p = 3.0 * self.beta_start / (self.beta_end - self.beta_start) +# q = 3.0 * c / (self.beta_end - self.beta_start) +# a = -0.5 * q + torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0) +# b = -0.5 * q - torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0) +# return torch.pow(a, 1.0 / 3.0) + torch.pow(b, 1.0 / 3.0) +# else: +# raise NotImplementedError +# +# def mixing_component(self, x_noisy, var_t, t, enabled): +# if enabled: +# return torch.sqrt(var_t) * x_noisy +# else: +# return None +# +# +# class DiffusionSubPowerVPSDE(DiffusionBase): +# """ +# Diffusion implementation of the sub-power-VPSDE. This uses the same SDE like DiffusionSubVPSDE but with beta +# function that is a power function with user specified power. Note that for power=1, this reproduces the vanilla +# DiffusionSubVPSDE above. +# """ +# def __init__(self, args): +# super().__init__(args) +# self.beta_start = args.beta_start +# self.beta_end = args.beta_end +# self.power = args.vpsde_power +# +# def f(self, t): +# return -0.5 * self.beta(t) +# +# def g2(self, t): +# return self.beta(t) * (1.0 - torch.exp(-2.0 * self.beta_start * t - 2.0 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))) +# +# def var(self, t): +# int_term = torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)) +# return torch.square(1.0 - int_term) + self.sigma2_0 * int_term +# +# def e2int_f(self, t): +# return torch.exp(-0.5 * self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)) +# +# def beta(self, t): +# """ internal auxiliary beta function """ +# return self.beta_start + (self.beta_end - self.beta_start) * t ** self.power +# +# def inv_var(self, var): +# raise NotImplementedError +# +# def mixing_component(self, x_noisy, var_t, t, enabled): +# if enabled: +# int_term = torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)).view(-1, 1, 1, 1) +# return torch.sqrt(var_t) * x_noisy / (torch.square(1.0 - int_term) + int_term) +# else: +# return None +# +# def var_vpsde(self, t): +# return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)) +# +# def inv_var_vpsde(self, var): +# if self.power == 2: +# c = torch.log((1 - var) / (1 - self.sigma2_0)) +# p = 3.0 * self.beta_start / (self.beta_end - self.beta_start) +# q = 3.0 * c / (self.beta_end - self.beta_start) +# a = -0.5 * q + torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0) +# b = -0.5 * q - torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0) +# return torch.pow(a, 1.0 / 3.0) + torch.pow(b, 1.0 / 3.0) +# else: +# raise NotImplementedError +# +# +# class DiffusionVESDE(DiffusionBase): +# """ +# Diffusion 
implementation of the VESDE with dz = sqrt(beta(t)) * dW +# """ +# def __init__(self, args): +# super().__init__(args) +# self.sigma2_min = args.sigma2_min +# self.sigma2_max = args.sigma2_max +# assert self.sigma2_min == self.sigma2_0, "VESDE was proposed implicitly assuming sigma2_min = sigma2_0!" +# +# def f(self, t): +# return torch.zeros_like(t, device='cuda') +# +# def g2(self, t): +# return self.sigma2_min * np.log(self.sigma2_max / self.sigma2_min) * ((self.sigma2_max / self.sigma2_min) ** t) +# +# def var(self, t): +# return self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t) - self.sigma2_min + self.sigma2_0 +# +# def e2int_f(self, t): +# return torch.ones_like(t, device='cuda') +# +# def inv_var(self, var): +# return torch.log((var + self.sigma2_min - self.sigma2_0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min) +# +# def mixing_component(self, x_noisy, var_t, t, enabled): +# if enabled: +# return torch.sqrt(var_t) * x_noisy / (self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t.view(-1, 1, 1, 1)) - self.sigma2_min + 1.0) +# else: +# return None +# +# def var_N(self, t): +# return 1.0 - self.sigma2_min + self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t) +# +# def inv_var_N(self, var): +# return torch.log((var + self.sigma2_min - 1.0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min) + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + class Foo: + def __init__(self): + self.sde_type = 'vpsde' + self.sigma2_0 = 0.01 + self.sigma2_min = 3e-5 + self.sigma2_max = 0.999 + self.beta_start = 0.1 + self.beta_end = 20 + + # A unit test to check the implementation of e2intf and var_t + diff = make_diffusion(Foo()) + + print(diff.inv_var(diff.var(torch.tensor(0.5)))) + exit() + + delta = 1e-8 + t = np.arange(start=0.001, stop=0.999, step=delta) + t = torch.tensor(t) + + f_t = diff.f(t) + e2intf = diff.e2int_f(t) + # compute finite differences for e2intf + grad_fd = (e2intf[1:] - e2intf[:-1]) / delta + grad_analytic = f_t[:-1] * e2intf[:-1] + print(torch.max(torch.abs(grad_fd - grad_analytic))) + + var_t = diff.var(t) + grad_fd = (var_t[1:] - var_t[:-1]) / delta + grad_analytic = (2 * f_t * var_t + diff.g2(t))[:-1] + print(torch.max(torch.abs(grad_fd - grad_analytic))) diff --git a/utils/diffusion_pvd.py b/utils/diffusion_pvd.py new file mode 100644 index 0000000..b7204dd --- /dev/null +++ b/utils/diffusion_pvd.py @@ -0,0 +1,563 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +"""copied and modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_discretized.py""" +import torch +from torch.cuda.amp import autocast +import numpy as np +from utils.diffusion import make_beta_schedule +from utils import utils +from loguru import logger + + +class DiffusionDiscretized(object): + """ + This class constructs the diffusion process and provides all related methods and constants. 
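+ + Forward process (discrete DDPM): + q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I), + where alpha_bar_t is the cumulative product of (1 - beta_s). + iw_quantities() returns these two coefficients as (weight_init, + weight_noise_power) and sample_q() applies them to draw x_t.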
+ """ + + def __init__(self, args, var_fun, cfg): # alpha_bars_fun + self.cfg = cfg + + self._diffusion_steps = cfg.ddpm.num_steps # args.diffusion_steps + self._denoising_stddevs = 'beta' # args.denoising_stddevs + #self._var_fun = var_fun + beta_start = cfg.ddpm.beta_1 + beta_end = cfg.ddpm.beta_T + mode = cfg.ddpm.sched_mode + num_steps = cfg.ddpm.num_steps + self.p2_gamma = cfg.ddpm.p2_gamma + self.p2_k = cfg.ddpm.p2_k + self.use_p2_weight = self.cfg.ddpm.use_p2_weight + + logger.info( + f'[Build Discrete Diffusion object] beta_start={beta_start}, beta_end={beta_end}, mode={mode}, num_steps={num_steps}') + self.betas = make_beta_schedule( + mode, beta_start, beta_end, num_steps).numpy() + self._betas_init, self._alphas, self._alpha_bars, self._betas_post_init, self.snr = \ + self._generate_base_constants( + diffusion_steps=self._diffusion_steps) + + def iw_quantities_t(self, B, timestep, *args): + timestep = timestep.view(B) + timestep = timestep + 1 # [1,T] + alpha_bars = torch.gather(self._alpha_bars, 0, timestep-1) # [0,T-1] + weight_init = alpha_bars_sqrt = torch.sqrt(alpha_bars) + weight_noise_power = 1.0 - alpha_bars + + weight_noise_power = weight_noise_power[:, None, None, None] + weight_init = weight_init[:, None, None, None] + if self.use_p2_weight: + p2_weight = torch.gather( + 1 / (self.p2_k + self.snr)**self.p2_gamma, 0, timestep-1).view(B) + loss_weight = p2_weight + else: + loss_weight = 1.0 + return timestep, weight_noise_power, weight_init, loss_weight, None, None + + def iw_quantities(self, B, *args): + rho = torch.rand(size=[B], device='cuda') * self._diffusion_steps + timestep = rho.type(torch.int64) # [0, T-1] + assert(timestep.max() <= self._diffusion_steps - + 1), f'get max at {timestep.max()}' + timestep = timestep + 1 # [1,T] + alpha_bars = torch.gather(self._alpha_bars, 0, timestep-1) # [0,T-1] + weight_init = alpha_bars_sqrt = torch.sqrt(alpha_bars) + weight_noise_power = 1.0 - alpha_bars + + weight_noise_power = weight_noise_power[:, None, None, None] + weight_init = weight_init[:, None, None, None] + if self.use_p2_weight: + p2_weight = torch.gather( + 1 / (self.p2_k + self.snr)**self.p2_gamma, 0, timestep-1).view(B) + loss_weight = p2_weight + else: + loss_weight = 1.0 + + return timestep, weight_noise_power, weight_init, loss_weight, None, None + + def debug_sheduler(self): + rho = torch.range(0, 1000-1).cuda() # / 1000.0 + time_eps + timestep = rho.type(torch.int64) # [0, T-1] + assert(timestep.max() <= self._diffusion_steps - + 1), f'get max at {timestep.max()}' + timestep = timestep + 1 # [1,T] + alpha_bars = torch.gather(self._alpha_bars, 0, timestep-1) # [0,T-1] + weight_init = alpha_bars_sqrt = torch.sqrt(alpha_bars) + weight_noise_power = 1.0 - alpha_bars + + weight_noise_power = weight_noise_power[:, None, None, None] + weight_init = weight_init[:, None, None, None] + return timestep, weight_noise_power, weight_init, 1, None, None + + def sample_q(self, x_init, noise, var_t, m_t): + """ returns a sample from diffusion process at time t + x_init: [B,ND,1,1] + noise: + vae_t: weight noise; [B,1,1,1] + m_t: weight init; [B,1,1,1] + """ + assert(len(x_init.shape) == 4) + assert(len(var_t.shape) == 4) + assert(len(m_t.shape) == 4) + #CHECK4D(x_init) + #CHECK4D(var_t) + #CHECK4D(m_t) + #CHECKEQ(x_init.shape[0], m_t.shape[0]) + assert(x_init.shape[0] == m_t.shape[0]) + output = m_t * x_init + torch.sqrt(var_t) * noise + + return output + + def cross_entropy_const(self, ode_eps): + return 0 + + def _generate_base_constants(self, diffusion_steps): + """ + 
Generates torch tensors with basic constants for all timesteps. + """ + betas_np = self.betas # self._generate_betas_from_continuous_fun(diffusion_steps) + + alphas_np = 1.0 - betas_np + alphas_cumprod = alpha_bars_np = np.cumprod(alphas_np) + snr = 1.0 / (1 - alphas_cumprod) - 1 + + # posterior variances only make sense for t>1, hence the array is short by 1 + betas_post_np = betas_np[1:] * \ + (1.0 - alpha_bars_np[:-1]) / (1.0 - alpha_bars_np[1:]) + # we add beta_post_2 to the beginning of both beta arrays, since this is used as final decoder variance and + # requires special treatment (as in diffusion paper) + betas_post_init_np = np.append(betas_post_np[0], betas_post_np) + #betas_init_np = np.append(betas_post_np[0], betas_np[1:]) + + betas_init = torch.from_numpy(betas_np).float().cuda() + snr = torch.from_numpy(snr).float().cuda() + alphas = torch.from_numpy(alphas_np).float().cuda() + alpha_bars = torch.from_numpy(alpha_bars_np).float().cuda() + betas_post_init = torch.from_numpy(betas_post_init_np).float().cuda() + + return betas_init, alphas, alpha_bars, betas_post_init, snr + + # def _generate_betas_from_continuous_fun(self, diffusion_steps): + # t = np.arange(1, diffusion_steps + 1, dtype=np.float64) + # t = t / diffusion_steps + + # # alpha_bars = self._alpha_bars_fun(t) + # alpha_bars = 1.0 - self._var_fun(torch.tensor(t)).numpy() + # betas = 1 - alpha_bars[1:] / alpha_bars[:-1] + # betas = np.hstack((1 - alpha_bars[0], betas)) + + # return betas + + def get_p_log_scales(self, timestep, stddev_type): + """ + Grab log std devs. of backward denoising process p, if we decide to fix them. + """ + if stddev_type == 'beta': + # use diffusion variances, except for t=1, for which we use posterior variance beta_post_2 + return 0.5 * torch.log(torch.gather(self._betas_init, 0, timestep-1)) + elif stddev_type == 'beta_post': + # use diffusion posterior variances, except for t=1, for which there is no posterior, so we use beta_post_2 + return 0.5 * torch.log(torch.gather(self._betas_post_init, 0, timestep-1)) + elif stddev_type == 'learn': + return None + else: + raise ValueError('Unknown stddev_type: {}'.format(stddev_type)) + # @torch.no_grad() + # def debug_run_denoising_diffusion(self, model, num_samples, shape, x_noisy, timestep, + # temp=1.0, enable_autocast=False, is_image=False, prior_var=1.0, + # condition_input=None): + # """ + # Run the full denoising sampling loop. + # """ + # # set model to eval mode + # # initialize sample + # #x_noisy_size = [num_samples] + shape + # #x_noisy = torch.randn(size=x_noisy_size, device='cuda') ## * np.sqrt(prior_var) * temp + # model.eval() + # x_noisy_size = x_noisy.shape + + # x_noisy = x_noisy[0:1].expand(x_noisy.shape[0],-1,-1,-1) # + # timestep_start = timestep[0].item() + # output_list = [] + # output_pred_list = [] + # logger.info('timestep_start: {}', timestep_start) + # # denoising loop + # for t in reversed(range(0, self._diffusion_steps)): + # if t > timestep_start: + # continue + # if t % 100 == 0: + # logger.info('t={}', t) + # timestep = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1) # the model uses (1 ... 
T) without 0 + # fixed_log_scales = self.get_p_log_scales(timestep=timestep, stddev_type=self._denoising_stddevs) + # mixing_component = self.get_mixing_component(x_noisy, timestep, enabled=model.mixed_prediction) + + # # run model + # with autocast(enable_autocast): + # pred_logits = model(x=x_noisy, t=timestep.float() , condition_input=condition_input) + # # pred_logits = model(x_noisy, timestep.float() / self._diffusion_steps) + # logits = utils.get_mixed_prediction(model.mixed_prediction, pred_logits, model.mixing_logit, mixing_component) + + # output_dist = utils.decoder_output('place_holder', logits, fixed_log_scales=fixed_log_scales) + # noise = torch.randn(size=x_noisy_size, device='cuda') + # mean = self.get_q_posterior_mean(x_noisy, output_dist.means, t) + + # _, var_t_p, m_t_p, _, _, _ = self.iw_quantities_t( + # num_samples, timestep) + # pred_eps_t0 = (x_noisy - torch.sqrt(var_t_p) * pred_logits) / m_t_p + # if t == 0: + # x_image = mean + # else: + # x_noisy = mean + torch.exp(output_dist.log_scales) * noise * temp + # output_list.append(x_noisy) + # output_pred_list.append(pred_eps_t0) + # if is_image: + # x_image = x_image.clamp(min=-1., max=1.) + # x_image = utils.unsymmetrize_image_data(x_image) + # model.train() + # return x_image, output_list, output_pred_list + + @torch.no_grad() + def run_denoising_diffusion(self, model, num_samples, shape, temp=1.0, + enable_autocast=False, is_image=False, prior_var=1.0, + condition_input=None, given_noise=None, clip_feat=None, cls_emb=None, grid_emb=None): + """ + Run the full denoising sampling loop. + """ + # set model to eval mode + model.eval() + + # initialize sample + x_noisy_size = [num_samples] + shape + if given_noise is None: + # * np.sqrt(prior_var) * temp + x_noisy = torch.randn(size=x_noisy_size, device='cuda') + else: + x_noisy = given_noise[0] + output_list = {} + output_list['pred_x'] = [] + # output_list['init_x_noisy'] = x_noisy + # output_list['input_x'] = [] + # output_list['input_t'] = [] + # output_list['output_e'] = [] + # output_list['noise_t'] = [] + # output_list['condition_input'] = [] + # denoising loop + kwargs = {} + if grid_emb is not None: + kwargs['grid_emb'] = grid_emb + for t in reversed(range(0, self._diffusion_steps)): + if t % 500 == 0: + logger.info('t={}; shape={}, num_samples={}, sample shape: {}', + t, shape, num_samples, x_noisy.shape) + # the model uses (1 ... 
T) without 0 + timestep = torch.ones( + num_samples, dtype=torch.int64, device='cuda') * (t+1) + fixed_log_scales = self.get_p_log_scales( + timestep=timestep, stddev_type=self._denoising_stddevs) + mixing_component = self.get_mixing_component( + x_noisy, timestep, enabled=model.mixed_prediction) + + # run model + with autocast(enable_autocast): + if cls_emb is not None and condition_input is not None: + condition_input = torch.cat( + [condition_input, cls_emb], dim=1) + elif cls_emb is not None and condition_input is None: + condition_input = cls_emb + # output_list['input_x'].append(x_noisy) + # output_list['input_t'].append(timestep) + # output_list['condition_input'].append(condition_input) + + pred_logits = model(x=x_noisy, t=timestep.float(), + condition_input=condition_input, clip_feat=clip_feat, **kwargs) + # output_list['output_e'].append(pred_logits) + + # pred_logits = model(x_noisy, timestep.float() / self._diffusion_steps) + logits = utils.get_mixed_prediction( + model.mixed_prediction, pred_logits, model.mixing_logit, mixing_component) + + output_dist = utils.decoder_output( + 'place_holder', logits, fixed_log_scales=fixed_log_scales) + if given_noise is None: + noise = torch.randn(size=x_noisy_size, device='cuda') + else: + # torch.randn(size=x_noisy_size, device='cuda') + noise = given_noise[1][t] + + mean = self.get_q_posterior_mean(x_noisy, output_dist.means, t) + if t == 0: + x_image = mean + else: + x_noisy = mean + \ + torch.exp(output_dist.log_scales) * noise * temp + # output_list['noise_t'].append(noise) + output_list['pred_x'].append(x_noisy) + if is_image: + x_image = x_image.clamp(min=-1., max=1.) + x_image = utils.unsymmetrize_image_data(x_image) + model.train() + return x_image, output_list + + def run_ddim_forward(self, dae, eps, ddim_step, ddim_skip_type, condition_input=None, clip_feat=None): + ## raise NotImplementedError + """ calculates NLL based on ODE framework, assuming integration cutoff ode_eps """ + model.eval() + + # initialize sample + x_noisy_size = [num_samples] + shape + x_noisy = torch.randn( + size=x_noisy_size, device='cuda') if x_noisy is None else x_noisy.cuda() + output_list = [] + S = ddim_step + + # even spaced t + if skip_type == 'uniform': + c = (self._diffusion_steps - 1.0) / (S - 1.0) + list_tau = [np.floor(i * c) for i in range(S)] + list_tau = [int(s) for s in list_tau] + elif skip_type == 'quad': + seq = (np.linspace( + 0, np.sqrt(self._diffusion_steps * 0.8), S + ) ** 2 + ) + list_tau = [int(s) for s in list(seq)] + + user_defined_steps = sorted(list(list_tau), reverse=True) + T_user = len(user_defined_steps) + kwargs = {} + if grid_emb is not None: + kwargs['grid_emb'] = grid_emb + + def ode_func(t, x): + """ the ode function (including log probability integration for NLL calculation) """ + global nfe_counter + nfe_counter = nfe_counter + 1 + + x = x.detach() + x.requires_grad_(False) + with torch.set_grad_enabled(False): + with autocast(enabled=enable_autocast): + variance = self.var(t=t) + mixing_component = self.mixing_component( + x_noisy=x, var_t=variance, t=t, enabled=dae.mixed_prediction) + pred_params = dae( + x=x, t=t, condition_input=condition_input, clip_feat=clip_feat) + # Warning: here mixing_logit can be NOne + params = get_mixed_prediction( + dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component) + dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \ + params / torch.sqrt(variance) + # dx_dt = - 0.5 * self.g2(t=t) * (x - params / torch.sqrt(variance)) + + # with autocast(enabled=False): + # dlogp_x_dt 
= -trace_df_dx_hutchinson(dx_dt, x, noise, no_autograd).view(x.shape[0], 1) + + return dx_dt + + # NFE counter + global nfe_counter + + nll_all, nfe_all = [], [] + for i in range(num_samples): + # integrated log probability + # logp_diff_t0 = torch.zeros(eps.shape[0], 1, device='cuda') + + nfe_counter = 0 + + # solve the ODE + x_t = odeint( + ode_func, + eps, + torch.tensor([ode_eps, 1.0], device='cuda'), + atol=ode_solver_tol, # 1e-5 + rtol=ode_solver_tol, # 1e-5 + # 'dopri5' or 'dopri8' methods also seems good. + method="scipy_solver", + options={"solver": 'RK45'}, # only for scipy solvers + ) + x_t0 = x_t[-1] + + nfe_all.append(nfe_counter) + print('nfe_counter: ', nfe_counter) + + return x_t0 + + @torch.no_grad() + def run_ddim(self, model, num_samples, shape, temp=1.0, enable_autocast=False, is_image=True, prior_var=1.0, + condition_input=None, ddim_step=100, skip_type='uniform', kappa=1.0, clip_feat=None, grid_emb=None, + x_noisy=None, dae_index=-1): + """ + Run the full denoising sampling loop. + kappa = 1.0 # this one is the eta in DDIM algorithm + """ + # set model to eval mode + model.eval() + + # initialize sample + x_noisy_size = [num_samples] + shape + x_noisy = torch.randn( + size=x_noisy_size, device='cuda') if x_noisy is None else x_noisy.cuda() + output_list = [] + S = ddim_step + + # even spaced t + if skip_type == 'uniform': + c = (self._diffusion_steps - 1.0) / (S - 1.0) + list_tau = [np.floor(i * c) for i in range(S)] + list_tau = [int(s) for s in list_tau] + elif skip_type == 'quad': + seq = (np.linspace( + 0, np.sqrt(self._diffusion_steps * 0.8), S + ) ** 2 + ) + list_tau = [int(s) for s in list(seq)] + + user_defined_steps = sorted(list(list_tau), reverse=True) + T_user = len(user_defined_steps) + kwargs = {} + if grid_emb is not None: + kwargs['grid_emb'] = grid_emb + # denoising loop + # for t in user_defined_steps: ## reversed(range(0, self._diffusion_steps)): + Alpha_bar = self._alpha_bars # self.var_sched.alphas_cumprod + # the following computation is the same as the function in https://github.com/ermongroup/ddim/blob/51cb290f83049e5381b09a4cc0389f16a4a02cc9/functions/denoising.py#L10 + for i, t in enumerate(user_defined_steps): + if i % 500 == 0: + logger.info('t={} / {}, ori={}', i, S, self._diffusion_steps) + tau = t + # the model uses (1 ... 
T) without 0 + timestep = torch.ones( + num_samples, dtype=torch.int64, device='cuda') * (t+1) + fixed_log_scales = self.get_p_log_scales( + timestep=timestep, stddev_type=self._denoising_stddevs) + mixing_component = self.get_mixing_component( + x_noisy, timestep, enabled=model.mixed_prediction) + + # --- copied --- # + if i == T_user - 1: # the next step is to generate x_0 + assert t == 0 + alpha_next = torch.tensor(1.0) + sigma = torch.tensor(0.0) + else: + alpha_next = Alpha_bar[user_defined_steps[i+1]] + sigma = kappa * \ + torch.sqrt( + (1-alpha_next) / (1-Alpha_bar[tau]) * (1 - Alpha_bar[tau] / alpha_next)) + + x = x_noisy * torch.sqrt(alpha_next / Alpha_bar[tau]) + c = torch.sqrt(1 - alpha_next - sigma ** 2) - torch.sqrt(1 - + Alpha_bar[tau]) * torch.sqrt(alpha_next / Alpha_bar[tau]) + + # --- run model forward --- # + with autocast(enable_autocast): + pred_logits = model(x=x_noisy, t=timestep.float( + ), condition_input=condition_input, clip_feat=clip_feat, **kwargs) + # pred_logits = model(x_noisy, timestep.float() / self._diffusion_steps) + logits = utils.get_mixed_prediction( + model.mixed_prediction, pred_logits, model.mixing_logit, mixing_component) + epsilon_theta = logits + # xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et + # x_{t-1} = c * et + sigma * randn + sqrt(alpha_next / alpha_bar_t) * x_t + x += c * epsilon_theta + sigma * \ + torch.randn(x_noisy_size).to(x.device) + x_noisy = x + output_list.append(x_noisy) + # if is_image: + # x_image = x_image.clamp(min=-1., max=1.) + # x_image = utils.unsymmetrize_image_data(x_image) + model.train() + return x_noisy, output_list + + def get_q_posterior_mean(self, x_noisy, prediction, t): + # last step works differently (for better FIDs we NEVER sample in last conditional images output!) + # Line 4 in algorithm 2 in DDPM: + if t == 0: + mean = 1.0 / torch.sqrt(self._alpha_bars[0]) * \ + (x_noisy - torch.sqrt(1.0 - self._alpha_bars[0]) * prediction) + else: + mean = 1.0 / torch.sqrt(self._alphas[t]) * \ + (x_noisy - self._betas_init[t] * prediction / + torch.sqrt(1.0 - self._alpha_bars[t])) + + return mean + + def get_mixing_component(self, x_noisy, timestep, enabled): + size = x_noisy.size() + alpha_bars = torch.gather(self._alpha_bars, 0, timestep-1) + if enabled: + one_minus_alpha_bars_sqrt = utils.view4D( + torch.sqrt(1.0 - alpha_bars), size) + mixing_component = one_minus_alpha_bars_sqrt * x_noisy + else: + mixing_component = None + + return mixing_component + + def mixing_component(self, eps, var, t, enabled): + return self.get_mixing_component(eps, t, enabled) + + @torch.no_grad() + def run_denoising_diffusion_from_t(self, model, num_samples, shape, time_start, x_noisy, + temp=1.0, enable_autocast=False, is_image=False, prior_var=1.0, + condition_input=None, given_noise=None): + """ + Run the full denoising sampling loop. + given_noise: Nstep,*x_noisy_size + """ + # set model to eval mode + model.eval() + + # initialize sample + x_noisy_size = [num_samples] + shape + # if given_noise is None: + ## raise ValueError('given_noise is required') + # raise NotImplementedError + # x_noisy = torch.randn(size=x_noisy_size, device='cuda') ## * np.sqrt(prior_var) * temp + # else: + ## x_noisy = given_noise[0] + + output_list = [] + # denoising loop + for t in reversed(range(0, time_start)): # self._diffusion_steps)): + # if t % 100 == 0: + # logger.info('t={}', t) + # the model uses (1 ... 
T) without 0 + timestep = torch.ones( + num_samples, dtype=torch.int64, device='cuda') * (t+1) + fixed_log_scales = self.get_p_log_scales( + timestep=timestep, stddev_type=self._denoising_stddevs) + mixing_component = self.get_mixing_component( + x_noisy, timestep, enabled=model.mixed_prediction) + + # run model + with autocast(enable_autocast): + pred_logits = model( + x=x_noisy, t=timestep.float(), condition_input=condition_input) + # pred_logits = model(x_noisy, timestep.float() / self._diffusion_steps) + logits = utils.get_mixed_prediction( + model.mixed_prediction, pred_logits, model.mixing_logit, mixing_component) + + output_dist = utils.decoder_output( + 'place_holder', logits, fixed_log_scales=fixed_log_scales) + if given_noise is None: + noise = torch.randn(size=x_noisy_size, device='cuda') + else: + # torch.randn(size=x_noisy_size, device='cuda') + noise = given_noise[1][t] + + mean = self.get_q_posterior_mean(x_noisy, output_dist.means, t) + if t == 0: # < 10: + x_image = mean + else: + x_noisy = mean + \ + torch.exp(output_dist.log_scales) * noise * temp + output_list.append(x_noisy) + # if is_image: + # x_image = x_image.clamp(min=-1., max=1.) + # x_image = utils.unsymmetrize_image_data(x_image) + model.train() + return x_image, output_list diff --git a/utils/ema.py b/utils/ema.py new file mode 100644 index 0000000..8248e2e --- /dev/null +++ b/utils/ema.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +""" src: ddim/model/ema.py +implement the EMA model +usage: + ema_helper = EMAHelper(mu=self.config.model.ema_rate) + ema_helper.register(model) + ema_helper.load_state_dict(states[-1]) + ema_helper.ema(model) + +after optimizer.step() + ema_helper.update(model) + +copied and modified from + https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/util/ema.py +""" + +import warnings +import torch +from torch.optim import Optimizer +from loguru import logger +import torch.nn as nn +import os + + +class EMA(Optimizer): + def __init__(self, opt, ema_decay): + self.ema_decay = ema_decay + self.apply_ema = self.ema_decay > 0. + logger.info('[EMA] apply={}', self.apply_ema) + self.optimizer = opt + self.state = opt.state + self.param_groups = opt.param_groups + + def zero_grad(self): + self.optimizer.zero_grad() + + def step(self, *args, **kwargs): + retval = self.optimizer.step(*args, **kwargs) + + # stop here if we are not applying EMA + if not self.apply_ema: + return retval + + for group in self.optimizer.param_groups: + ema, params = {}, {} + for i, p in enumerate(group['params']): + if p.grad is None: + continue + state = self.optimizer.state[p] + + # State initialization + if 'ema' not in state: + state['ema'] = p.data.clone() + + if p.shape not in params: + params[p.shape] = {'idx': 0, 'data': []} + ema[p.shape] = [] + + params[p.shape]['data'].append(p.data) + ema[p.shape].append(state['ema']) + + for i in params: + params[i]['data'] = torch.stack(params[i]['data'], dim=0) + ema[i] = torch.stack(ema[i], dim=0) + ema[i].mul_(self.ema_decay).add_( + params[i]['data'], alpha=1. 
- self.ema_decay)
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                idx = params[p.shape]['idx']
+                self.optimizer.state[p]['ema'] = ema[p.shape][idx, :]
+                params[p.shape]['idx'] += 1
+
+        return retval
+
+    def load_state_dict(self, state_dict):
+        super(EMA, self).load_state_dict(state_dict)
+        # load_state_dict loads the data to self.state and self.param_groups. We need to pass this data to
+        # the underlying optimizer too.
+        # logger.info('state size: {}', len(self.state))
+        self.optimizer.state = self.state
+        self.optimizer.param_groups = self.param_groups
+
+    def swap_parameters_with_ema(self, store_params_in_ema):
+        """ This function swaps parameters with their ema values. It records original parameters in the ema
+        parameters, if store_params_in_ema is true."""
+
+        # stop here if we are not applying EMA
+        if not self.apply_ema:
+            warnings.warn(
+                'swap_parameters_with_ema was called when there are no EMA weights.')
+            return
+        logger.info('swap with ema')
+        count_no_found = 0
+        for group in self.optimizer.param_groups:
+            for i, p in enumerate(group['params']):
+                if not p.requires_grad:
+                    # logger.info('no swap for i={}, param shape={}', i, p.shape)
+                    continue
+                if p not in self.optimizer.state:
+                    count_no_found += 1
+                    # logger.info('not found i={}, {}/{} p {}', i,
+                    #             count_no_found, len(group['params']), p.shape)
+                    continue
+                # if count_no_found > 100:
+                #     logger.info('found: i={}, p={}', i, p.shape)
+                ema = self.optimizer.state[p]['ema']
+                if store_params_in_ema:
+                    tmp = p.data.detach()
+                    p.data = ema.detach()
+                    self.optimizer.state[p]['ema'] = tmp
+                else:
+                    p.data = ema.detach()
diff --git a/utils/eval_helper.py b/utils/eval_helper.py
new file mode 100644
index 0000000..95585fb
--- /dev/null
+++ b/utils/eval_helper.py
@@ -0,0 +1,341 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
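+# (editor's note) Minimal usage sketch -- not part of the original patch -- for
+# the EMA optimizer wrapper defined in utils/ema.py above. Assumes a CUDA
+# device; ema_decay=0.9999 mirrors the value hard-coded in get_opt() in
+# utils/utils.py.
+if __name__ == '__main__':
+    import torch
+    from utils.ema import EMA
+    _net = torch.nn.Linear(3, 3).cuda()
+    _opt = EMA(torch.optim.Adam(_net.parameters(), lr=1e-3), ema_decay=0.9999)
+    _net(torch.randn(8, 3, device='cuda')).pow(2).mean().backward()
+    _opt.step()  # one optimizer step; the shadow EMA copies update inside step()
+    # evaluate with the averaged weights, then swap the raw weights back:
+    _opt.swap_parameters_with_ema(store_params_in_ema=True)
+    _opt.swap_parameters_with_ema(store_params_in_ema=True)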
+import os +import json +from comet_ml import Experiment, OfflineExperiment +## import open3d as o3d +import time +import numpy as np +import torch +from loguru import logger +import torchvision +from PIL import Image +from utils.vis_helper import visualize_point_clouds_3d +from utils.data_helper import normalize_point_clouds +from utils.checker import * +import torchvision +import sys +import math +from utils.evaluation_metrics_fast import compute_all_metrics, \ + jsd_between_point_cloud_sets, print_results, write_results +from utils.evaluation_metrics_fast import EMD_CD +CD_ONLY = int(os.environ.get('CD_ONLY', 0)) +VIS = 1 + +def pair_vis(gen_x, tr_x, titles, subtitles, writer, step=-1): + img_list = [] + num_recon = len(gen_x) + for i in range(num_recon): + points = gen_x[i] + points = normalize_point_clouds([tr_x[i], points]) + img = visualize_point_clouds_3d(points, subtitles[i]) + img_list.append(torch.as_tensor(img) / 255.0) + grid = torchvision.utils.make_grid(img_list, nrow=num_recon//2) + if writer is not None: + writer.add_image(titles, grid, step) + +def compute_NLL_metric(gen_pcs, ref_pcs, device, writer=None, output_name='', batch_size=200, step=-1, tag=''): + # evaluate the reconstrution results + metrics = EMD_CD(gen_pcs.to(device), ref_pcs.to(device), + batch_size=batch_size, accelerated_cd=True, reduced=False) + titles = 'nll/first-10-%s' % tag + k1, k2 = list(metrics.keys()) + subtitles = [['ori', 'gen-%s=%.1fx1e-2;%s=%.1fx1e-2' % + (k1, metrics[k1][j]*1e2, k2, metrics[k2][j]*1e2)] for j in range(10)] + pair_vis(gen_pcs[:10], ref_pcs[:10], titles, subtitles, writer, step=step) + results = {} + + for k in metrics.keys(): + sorted, indices = torch.sort(metrics[k]) + worse_ten, worse_score = indices[-10:], sorted[-10:] + titles = 'nll/worst-%s-%s' % (k, tag) + subtitles = [['ori', 'gen-%s=%.2fx1e-2' % + (k, worse_score[j]*1e2)] for j in range(len(worse_score))] + pair_vis(gen_pcs[worse_ten], ref_pcs[worse_ten], + titles, subtitles, writer, step=step) + if 'score_detail' not in results: + results['score_detail'] = metrics[k] + metrics[k] = metrics[k].mean() + + logger.info('best 10: {}', indices[:10]) + results.update({k: v.item() for k, v in metrics.items()}) + output = '' + for k, v in results.items(): + if 'detail' in k: + continue + output += '%s=%.3fx1e-2 ' % (k, v*1e2) + logger.info('{}: {}', k, v) + if 'CD' in k: + score = v + + url = writer.url if writer is not None else '' + logger.info('\n' + '-'*60 + + f'\n{output_name} | \n{output} step={step} \n {url} \n ' + '-'*60) + return results + + +def get_ref_num(cats, luo_split=False): + #ref = './scripts/test_data/ref_%s.pt'%cats + #assert(os.path.exists(ref)), f'file not found: {ref}' + num_test = { + 'animal': 100, + 'airplane': 405, + 'airplane_ps': 405, + 'chair': 662, + 'chair_ps': 662, + 'car': 352, + 'car_ps': 352, + 'all': 1000, + 'mug': 22, + 'bottle': 43 + } + if luo_split: + num_test = { + 'airplane': 607, + 'chair': 989, + 'car': 528 + } + + assert(cats in num_test), f'not found: {cats} in {num_test}' + return num_test[cats] + + +def get_cats(cats): + # return the category name for this dataset + all_cats = ['airplane', 'chair', 'car', 'all', 'animal', 'mug', 'bottle'] + for c in all_cats: + if c in cats or c == cats: + cats = c + break + assert(cats in all_cats), f'not foud cats for {cats} in {all_cats}' + return cats + + +def get_ref_pt(cats, data_type="datasets.pointflow_datasets", luo_split=False): + cats = get_cats(cats) + root = './datasets/test_data/' + if 'pointflow' in data_type: + ref = 
'ref_val_%s.pt' % cats + elif 'neuralspline_datasets' in data_type: + ref = 'ref_ns_val_%s.pt' % cats + else: + logger.info('get_ref_pt not support data_type: %s' % data_type) + return None + + ref = os.path.join(root, ref) + assert(os.path.exists(ref)), f'file not found: {ref}' + return ref + + +#@torch.no_grad() +#def compute_score_fast(gen_pcs, ref_pcs, m_pcs, s_pcs, +# batch_size_test=256, device_str='cuda', cd_only=1, +# exp=None, verbose=False, +# device=None, accelerated_cd=True, writer=None, norm_box=False, **print_kwargs): +# """ used to eval the pcs during training; all the files will not be dumpped into disk (to save time) +# the ref_pcs will be part of the full dataset only +# Args: +# output_name (str) path to sample obj: tensor: (Nsample.Npoint.3or6) +# ref_name (str) path to torch obj: +# torch.save({'ref': ref_pcs, 'mean': m_pcs, 'std': s_pcs}, ref_name) +# print_kwargs (dict): entries: dataset, hash, step, epoch; +# """ +# if gen_pcs.shape[1] > ref_pcs.shape[1]: +# xperm = np.random.permutation(np.arange(gen_pcs.shape[1]))[ +# :ref_pcs.shape[1]] +# gen_pcs = gen_pcs[:, xperm] +# if ref_pcs.shape[0] > gen_pcs.shape[0]: +# ref_pcs = ref_pcs[:gen_pcs.shape[0]] +# m_pcs = m_pcs[:gen_pcs.shape[0]] +# s_pcs = s_pcs[:gen_pcs.shape[0]] +# elif ref_pcs.shape[0] < gen_pcs.shape[0]: +# gen_pcs = gen_pcs[:ref_pcs.shape[0]] +# +# device = torch.device(device_str) if device is None else device +# CHECKEQ(ref_pcs.shape[0], gen_pcs.shape[0]) +# N_ref = ref_pcs.shape[0] # subset it +# batch_size_test = N_ref # * 0.5 +# if gen_pcs.shape[2] == 6: # B,N,3 or 6 +# gen_pcs = gen_pcs[:, :, :3] +# ref_pcs = ref_pcs[:, :, :3] +# if norm_box: +# ref_pcs = 0.5 * torch.stack(normalize_point_clouds(ref_pcs), dim=0) +# gen_pcs = 0.5 * torch.stack(normalize_point_clouds(gen_pcs), dim=0) +# print_kwargs['dataset'] = print_kwargs.get('dataset', +# '')+'-normbox' +# +# #ref_pcs = normalize_point_clouds(ref_pcs) +# #gen_pcs = normalize_point_clouds(gen_pcs) +# # print_kwargs['dataset'] = print_kwargs.get('dataset', +# # '')+'-normbox' +# # logger.info('[data shape] ref_pcs: {}, gen_pcs: {}, mean={}, std={}; norm_box={}', +# # ref_pcs.shape, gen_pcs.shape, m_pcs.shape, s_pcs.shape, norm_box) +# elif m_pcs is not None and s_pcs is not None: +# ref_pcs = ref_pcs * s_pcs + m_pcs +# gen_pcs = gen_pcs * s_pcs + m_pcs +# # visualize first few samples: +# if VIS and writer is not None and writer.exp is not None or exp is not None: +# logger.info('vis the result') +# if exp is None: +# exp = writer.exp +# img_list = [] +# for i in range(min(20, ref_pcs.shape[0])): +# NORM_VIS = 0 +# if NORM_VIS: +# norm_ref, norm_gen = normalize_point_clouds([ +# ref_pcs[i], gen_pcs[i]]) +# else: +# norm_ref = ref_pcs[i] +# norm_gen = gen_pcs[i] +# img = visualize_point_clouds_3d([norm_ref, norm_gen], +# [f'ref-{i}', f'gen-{i}'], bound=0.5) +# img_list.append(torch.as_tensor(img) / 255.0) +# grid = torchvision.utils.make_grid(img_list) +# # to 3,H,W to H,W,3 +# ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute( +# 1, 2, 0).to('cpu', torch.uint8).numpy() +# exp.log_image(ndarr, 'samples/verse_%s' % +# print_kwargs.get('hash', '_'), step=print_kwargs.get('step', 0)) +# # epoch=print_kwargs.get('epoch', 0)) +# +# metric2 = 'EMD' if not cd_only else None +# results = compute_all_metrics(gen_pcs.to(device).float(), +# ref_pcs.to(device).float(), batch_size_test, +# accelerated_cd=accelerated_cd, metric2=metric2, +# verbose=verbose, +# **print_kwargs) +# print_results(results, **print_kwargs) +# +# return results + + 
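+# (editor's note) Sketch of the on-disk formats consumed by compute_score()
+# below, inferred from its docstring and body; the file names and the
+# (405, 2048, 3) shapes are hypothetical (airplane-sized), not from the patch:
+#
+#   import torch
+#   torch.save(torch.randn(405, 2048, 3), 'samples_airplane.pt')   # output_name
+#   ref = {'ref': torch.randn(405, 2048, 3),    # normalized reference clouds
+#          'mean': torch.zeros(405, 1, 3),      # per-shape de-normalization,
+#          'std': torch.ones(405, 1, 1)}        # applied as ref * std + mean
+#   torch.save(ref, 'ref_val_airplane.pt')      # ref_name
+#   results = compute_score('samples_airplane.pt', 'ref_val_airplane.pt')
+#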
+@torch.no_grad() +def compute_score(output_name, ref_name, batch_size_test=256, device_str='cuda', + device=None, accelerated_cd=True, writer=None, + exp=None, + norm_box=False, skip_write=False, **print_kwargs): + """ + Args: + output_name (str) path to sample obj: tensor: (Nsample.Npoint.3or6) + ref_name (str) path to torch obj: + torch.save({'ref': ref_pcs, 'mean': m_pcs, 'std': s_pcs}, ref_name) + print_kwargs (dict): entries: dataset, hash, step, epoch; + """ + logger.info('[compute sample metric] sample: {} and ref: {}', + output_name, ref_name) + ref = torch.load(ref_name) + ref_pcs = ref['ref'][:, :, :3] + m_pcs, s_pcs = ref['mean'], ref['std'] + gen_pcs = torch.load(output_name) + if gen_pcs.shape[1] > ref_pcs.shape[1]: + xperm = np.random.permutation(np.arange(gen_pcs.shape[1]))[ + :ref_pcs.shape[1]] + gen_pcs = gen_pcs[:, xperm] + if type(gen_pcs) is dict: + logger.info('WARNING: the gen_pcs is a dict, with key ' + 'as {}| usuaglly its a tensor ' + 'you perhaps takes the train data,', + gen_pcs.keys()) + gen_pcs = gen_pcs['ref'] + device = torch.device(device_str) if device is None else device + # batch_size_test = ref_pcs.shape[0] + logger.info('[data shape] ref_pcs: {}, gen_pcs: {}, mean={}, std={}; norm_box={}', + ref_pcs.shape, gen_pcs.shape, m_pcs.shape, s_pcs.shape, norm_box) + N_ref = ref_pcs.shape[0] # subset it + m_pcs = m_pcs[:N_ref] + s_pcs = s_pcs[:N_ref] + ref_pcs = ref_pcs[:N_ref] + gen_pcs = gen_pcs[:N_ref] + if gen_pcs.shape[2] == 6: # B,N,3 or 6 + gen_pcs = gen_pcs[:, :, :3] + ref_pcs = ref_pcs[:, :, :3] + if norm_box: + #ref_pcs = ref_pcs * s_pcs + m_pcs + #gen_pcs = gen_pcs * s_pcs + m_pcs + ref_pcs = 0.5 * torch.stack(normalize_point_clouds(ref_pcs), dim=0) + gen_pcs = 0.5 * torch.stack(normalize_point_clouds(gen_pcs), dim=0) + print_kwargs['dataset'] = print_kwargs.get('dataset', + '')+'-normbox' + else: + ref_pcs = ref_pcs * s_pcs + m_pcs + gen_pcs = gen_pcs * s_pcs + m_pcs + # visualize first few samples: + if VIS: + if exp is not None: + exp = exp + elif writer is not None: + exp = writer.exp + elif os.path.exists('.comet_api'): + comet_args = json.load(open('.comet_api', 'r')) + exp = Experiment(display_summary_level=0, + **comet_args) + else: + exp = OfflineExperiment(offline_directory="/tmp") + img_list = [] + gen_list = [] + ref_list = [] + for i in range(20): + NORM_VIS = 0 + if NORM_VIS: + norm_ref, norm_gen = normalize_point_clouds([ + ref_pcs[i], gen_pcs[i]]) + else: + norm_ref = ref_pcs[i] + norm_gen = gen_pcs[i] + ref_img = visualize_point_clouds_3d([norm_ref], + [f'ref-{i}'], bound=1.0) # 0.8) + gen_img = visualize_point_clouds_3d([norm_gen], + [f'gen-{i}'], bound=1.0) # 0.8) + ref_list.append(torch.as_tensor(ref_img) / 255.0) + gen_list.append(torch.as_tensor(gen_img) / 255.0) + img_list.append(ref_list[-1]) + img_list.append(gen_list[-1]) + + path = output_name.replace('.pt', '_eval.png') + + grid = torchvision.utils.make_grid(gen_list) + # to 3,H,W to H,W,3 + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute( + 1, 2, 0).to('cpu', torch.uint8).numpy() + exp.log_image(ndarr, 'samples') + + ref_grid = torchvision.utils.make_grid(ref_list) + # to 3,H,W to H,W,3 + ref_ndarr = ref_grid.mul(255).add_(0.5).clamp_(0, 255).permute( + 1, 2, 0).to('cpu', torch.uint8).numpy() + ndarr = np.concatenate([ndarr, ref_ndarr], axis=0) + exp.log_image(ndarr, 'samples_vs_ref') + + torchvision.utils.save_image(img_list, path) + logger.info(exp.url) + logger.info('save vis at {}', path) + metric2 = 'EMD' if not CD_ONLY else None + logger.info('print_kwargs: 
{}', print_kwargs) + results = compute_all_metrics(gen_pcs.to(device).float(), + ref_pcs.to(device).float(), batch_size_test, accelerated_cd=accelerated_cd, metric2=metric2, + **print_kwargs) + + jsd = jsd_between_point_cloud_sets( + gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy()) + results['jsd'] = jsd + msg = print_results(results, **print_kwargs) + # with open('../exp/eval_out.txt', 'a') as f: + # run_time = time.strftime('%m%d-%H%M-%S') + # f.write('<< date: %s >>\n' % run_time) + # f.write('%s\n%s\n' % (exp.url, msg)) + results['url'] = exp.url + if not skip_write: + os.makedirs('results', exist_ok=True) + msg = write_results( + os.path.join('./results/', 'eval_out.csv'), + results, **print_kwargs) + if metric2 is None: + logger.info('early exit') + exit() + return results + diff --git a/utils/evaluation_metrics_fast.py b/utils/evaluation_metrics_fast.py new file mode 100644 index 0000000..b1728f7 --- /dev/null +++ b/utils/evaluation_metrics_fast.py @@ -0,0 +1,687 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +""" +copied and modified from + https://github.com/luost26/diffusion-point-cloud/blob/910334a8975aa611423a920869807427a6b60efc/evaluation/evaluation_metrics.py +and + https://github.com/stevenygd/PointFlow/tree/b7a9216ffcd2af49b24078156924de025c4dbfb6/metrics +""" +import torch +import time +from tabulate import tabulate +import numpy as np +from loguru import logger +import warnings +from scipy.stats import entropy +from sklearn.neighbors import NearestNeighbors +from numpy.linalg import norm +from utils.exp_helper import ExpTimer +from third_party.PyTorchEMD.emd_nograd import earth_mover_distance_nograd +from third_party.PyTorchEMD.emd import earth_mover_distance +from third_party.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist_nograd, chamfer_3DDist +from utils.checker import * +import torch.nn.functional as F + + +def distChamferCUDA_l1(pred, target, points_dim=3): + import models.pvcnn.functional as pvcnn_fun + # expect B.2048.3 and B.2048.3 + B = pred.shape[0] + CHECKDIM(pred, 2, points_dim) + CHECKDIM(target, 2, points_dim) + CHECK3D(pred) + CHECK3D(target) + target_nndist, pred_nndist, target_nnidx, pred_nnidx \ + = chamfer_3DDist()(target[:, :, :3], pred[:, :, :3]) + target_normal = target.contiguous().permute(0, 2, 1).contiguous() # BN3->B3N + pred_normal = pred.contiguous().permute(0, 2, 1).contiguous() # BN3->B3N + + target_point_normal = pvcnn_fun.grouping( + target_normal, pred_nnidx[:, :, None]) # B,3,Np,1 + target_point_normal = target_point_normal.squeeze(-1) # B,3,Np + cham_norm_y = F.l1_loss(pred_normal.view(-1, points_dim), + target_point_normal.view(-1, points_dim), + reduction='sum') + + closest_pred_point_normal = pvcnn_fun.grouping( + pred_normal, target_nnidx[:, :, None]).squeeze(-1) # B,3,Np,1 -> B,3,Np, + cham_norm_y2 = F.l1_loss(closest_pred_point_normal.view(-1, points_dim), + target_normal.view(-1, points_dim), + reduction='sum') + + return cham_norm_y, cham_norm_y2 + ## target_nndist, pred_nndist, cham_norm_y, pred_with_gt_normal + # return nn_distance(x, y) + + +#def distChamferCUDA_withnormal(pred, target, 
normal_loss='cos'): +# # expect B.2048.3 and B.2048.3 +# import models.pvcnn.functional as pvcnn_fun +# B = pred.shape[0] +# CHECKDIM(pred, 2, 6) +# CHECKDIM(target, 2, 6) +# CHECK3D(pred) +# CHECK3D(target) +# +# target_nndist, pred_nndist, target_nnidx, pred_nnidx \ +# = chamfer_3DDist()(target[:, :, :3], pred[:, :, :3]) +# target_normal = target[:, :, 3:].contiguous().permute( +# 0, 2, 1).contiguous() # BN3->B3N +# pred_normal = pred[:, :, 3:].contiguous().permute( +# 0, 2, 1).contiguous() # BN3->B3N +# target_point_normal = pvcnn_fun.grouping( +# target_normal, pred_nnidx[:, :, None]) # B,3,Np,1 +# target_point_normal = target_point_normal.squeeze(-1) # B,3,Np +# if normal_loss == 'cos': +# pred_normal = pred_normal / \ +# (1e-8 + (pred_normal**2).sum(1, keepdim=True).sqrt()) +# cham_norm_y = 1 - torch.abs( +# F.cosine_similarity(pred_normal, target_point_normal, +# dim=1, eps=1e-6)) +# elif normal_loss == 'l2': +# cham_norm_y = F.mse_loss(pred_normal.view(B, -1), target_point_normal.view(B, -1), +# reduction='none').view(B, 3, -1).mean(1) +# else: +# raise NotImplementedError(normal_loss) +# pred_with_gt_normal = torch.cat([pred[:, :, :3], target_point_normal.permute(0, 2, 1)], +# dim=2).contiguous() +# CHECKEQ(cham_norm_y.shape[-1], pred_nndist.shape[-1]) +# +# return target_nndist, pred_nndist, cham_norm_y, pred_with_gt_normal +# # return nn_distance(x, y) + + +def distChamferCUDA(x, y): + # expect B.2048.3 and B.2048.3 + B = x.shape[0] + CHECKDIM(x, 2, 3) + CHECKDIM(y, 2, 3) + CHECK3D(x) + CHECK3D(y) + # assert (x.shape[-1] == 3 + # and y.shape[-1] == 3), f'get {x.shape} and {y.shape}' + dist1, dist2, _, _ = chamfer_3DDist()(x.cuda(), y.cuda()) + return dist1, dist2 + + +def distChamferCUDAnograd(x, y): + # expect B.2048.3 and B.2048.3 + assert (x.shape[-1] == 3 + and y.shape[-1] == 3), f'get {x.shape} and {y.shape}' + # return nn_distance_nograd(x, y) + # dist1, _, dist2, _ = NNDistance(x, y) + dist1, dist2, _, _ = chamfer_3DDist_nograd()(x.cuda(), y.cuda()) + return dist1, dist2 + + +def emd_approx(sample, ref, require_grad=True): + #B, N, N_ref = sample.size(0), sample.size(1), ref.size(1) + #assert N == N_ref, f"Not sure what would EMD do in this case; get N={N};N_ref={N_ref}" + # if not require_grad: + # t00 = time.time() + # match, _ = ApproxMatch(sample, ref) + # print('am: ', time.time() - t00) + # emd = MatchCost(sample, ref, match) + # del match + # # emd = match_cost_nograd(sample, ref) + # else: + # logger.info('error, required require_grad for faster compute ') + # exit() + # emd = match_cost(sample, ref) # (B,) + # emd_norm = emd / float(N) # (B,) + # logger.info('emd_norm: {} | sample: {}, ref: {}', + # emd_norm.shape, sample.shape, ref.shape) + if not require_grad: + emd_pyt = earth_mover_distance_nograd( + sample.cuda(), ref.cuda(), transpose=False) + else: + emd_pyt = earth_mover_distance( + sample.cuda(), ref.cuda(), transpose=False) + + #logger.info('emd_pyt: {}, diff: {}', emd_pyt.shape, ((emd_pyt - emd_norm)**2).sum()) + return emd_pyt + + +# def emd_approx(sample, ref, require_grad=True): +# B, N, N_ref = sample.size(0), sample.size(1), ref.size(1) +# assert N == N_ref, f"Not sure what would EMD do in this case; get N={N};N_ref={N_ref}" +# if not require_grad: +# t00 = time.time() +# match, _ = ApproxMatch(sample, ref) +# print('am: ', time.time() - t00) +# emd = MatchCost(sample, ref, match) +# del match +# # emd = match_cost_nograd(sample, ref) +# else: +# logger.info('error, required require_grad for faster compute ') +# exit() +# emd = match_cost(sample, 
ref) # (B,) +# emd_norm = emd / float(N) # (B,) +# logger.info('emd_norm: {} | sample: {}, ref: {}', +# emd_norm.shape, sample.shape, ref.shape) +# return emd_norm + + +# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet +def distChamfer(a, b): + x, y = a, b + bs, num_points, points_dim = x.size() + xx = torch.bmm(x, x.transpose(2, 1)) + yy = torch.bmm(y, y.transpose(2, 1)) + zz = torch.bmm(x, y.transpose(2, 1)) + diag_ind = torch.arange(0, num_points).to(a).long() + rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) + ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) + P = (rx.transpose(2, 1) + ry - 2 * zz) + return P.min(1)[0], P.min(2)[0] + + +def EMD_CD(sample_pcs, + ref_pcs, + batch_size, + accelerated_cd=False, + reduced=True, + require_grad=False): + N_sample = sample_pcs.shape[0] + N_ref = ref_pcs.shape[0] + assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample) + + cd_lst = [] + emd_lst = [] + iterator = range(0, N_sample, batch_size) + + for b_start in iterator: + b_end = min(N_sample, b_start + batch_size) + sample_batch = sample_pcs[b_start:b_end] + ref_batch = ref_pcs[b_start:b_end] + + if accelerated_cd and not require_grad: + dl, dr = distChamferCUDAnograd(sample_batch, ref_batch) + elif accelerated_cd: + dl, dr = distChamferCUDA(sample_batch, ref_batch) + else: + dl, dr = distChamfer(sample_batch, ref_batch) + cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1)) + + emd_batch = emd_approx(sample_batch, ref_batch, + require_grad=require_grad) + emd_lst.append(emd_batch) + + if reduced: + cd = torch.cat(cd_lst).mean() + emd = torch.cat(emd_lst).mean() + else: + cd = torch.cat(cd_lst) + emd = torch.cat(emd_lst) + + results = { + 'MMD-CD': cd, + 'MMD-EMD': emd, + } + return results + + +def formulate_results(results, dataset, hash, step, epoch): + reported = f'S{step}E{epoch}' + reported = '' if reported == 'SE' else reported + msg_head, msg_oneline = '', '' + if dataset != '-': + msg_head += "Dataset " + msg_oneline += f"{dataset} " + if hash != '-': + msg_head += "Model " + msg_oneline += f"{hash} " + if step != '' or epoch != '': + msg_head += 'reported ' + msg_oneline += f"{reported} " + + msg_head += "MMD-CDx0.001\u2193 MMD-EMDx0.01\u2193 COV-CD%\u2191 COV-EMD%\u2191 1-NNA-CD%\u2193 1-NNA-EMD%\u2193 JSD\u2193" + msg_oneline += f"{results.get('lgan_mmd-CD', 0)*1000:.4f} {results.get('lgan_mmd-EMD', 0)*100:.4f} {results.get('lgan_cov-CD', 0)*100:.2f} {results.get('lgan_cov-EMD', 0)*100:.2f} {results.get('1-NN-CD-acc', 0)*100:.2f} {results.get('1-NN-EMD-acc', 0)*100:.2f} {results.get('jsd', 0):.2f}" + if results.get('url', None) is not None: + msg_head += " url" + msg_oneline += f" {results.get('url', '-')}" + msg_oneline = msg_oneline.split(' ') + msg_head = msg_head.split(' ') + return msg_head, msg_oneline + + +def write_results(out_file, results, dataset='', hash='', step='', epoch=''): + msg_head, msg_oneline = formulate_results( + results, dataset, hash, step, epoch) + content2 = tabulate([msg_oneline], msg_head, tablefmt="tsv") + text_file = open(out_file, "a") + text_file.write(content2+'\n') + text_file.close() + return content2 + + +def print_results(results, dataset='-', hash='-', step='', epoch=''): + msg_head, msg_oneline = formulate_results( + results, dataset, hash, step, epoch) + msg = '{}'.format( + tabulate([msg_oneline], msg_head, tablefmt="plain")) + logger.info('\n{}', msg) + return msg + + +def _pairwise_EMD_CD_sub(metric, sample_batch, ref_pcs, N_ref, batch_size, accelerated_cd, verbose, require_grad): + cd_lst = [] + emd_lst = [] 
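+    # (editor's note) Scores ONE sample cloud against all N_ref reference
+    # clouds: the single cloud is tiled (view + expand) to the reference batch
+    # size below, so each CD/EMD kernel launch evaluates a whole batch of
+    # (sample, ref) pairs at once.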
+ sub_iterator = range(0, N_ref, batch_size) + total_iter = int(N_ref / float(batch_size) + 0.5) + # if verbose: + # import tqdm + # sub_iterator = tqdm.tqdm(sub_iterator, leave=False) + #t00 = time.time() + iter_id = 0 + for ref_b_start in sub_iterator: + ref_b_end = min(N_ref, ref_b_start + batch_size) + ref_batch = ref_pcs[ref_b_start:ref_b_end] + + batch_size_ref = ref_batch.size(0) + point_dim = ref_batch.size(2) + sample_batch_exp = sample_batch.view(1, -1, point_dim).expand( + batch_size_ref, -1, -1) + sample_batch_exp = sample_batch_exp.contiguous() + # print('before cuda {:.5f}s'.format(time.time() - t00)) + # t00 = time.time() + if metric == 'CD': + if accelerated_cd and not require_grad: + dl, dr = distChamferCUDAnograd(sample_batch_exp, ref_batch) + elif accelerated_cd: + dl, dr = distChamferCUDA(sample_batch_exp, ref_batch) + else: + dl, dr = distChamfer(sample_batch_exp, ref_batch) + # print('cuda: {:.5f}'.format(time.time() - t00)) + #t00 = time.time() + cd_lst.append(((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1)) + ) + elif metric == 'EMD': + emd_batch = emd_approx( + sample_batch_exp, ref_batch, require_grad=require_grad) + emd_lst.append(emd_batch.view(1, -1)) + else: + raise NotImplementedError + # torch.cuda.empty_cache() + # print('approx: {:.5f}'.format(time.time() - t00)) + if metric == 'CD': + cd_lst = torch.cat(cd_lst, dim=1) + return cd_lst, cd_lst + else: + emd_lst = torch.cat(emd_lst, dim=1) + return emd_lst, emd_lst + return cd_lst, emd_lst + + +def _pairwise_EMD_CD_(metric, + sample_pcs, + ref_pcs, + batch_size, + require_grad=True, + accelerated_cd=True, + verbose=True): + N_sample = sample_pcs.shape[0] + N_ref = ref_pcs.shape[0] + # N_sample = 50 + all_cd, all_emd = [], [] + iterator = range(N_sample) + total_iter = N_sample + exp_timer = ExpTimer(total_iter) + iter_id = 0 + print_every = max(int(total_iter // 3), 5) + for i, sample_b_start in enumerate(iterator): + exp_timer.tic() + if iter_id % print_every == 0 and iter_id > 0 and verbose: + logger.info('done {:02.1f}%({}) eta={:.1f}m', + 100.0*iter_id/total_iter, total_iter, + exp_timer.hours_left()*60) + sample_batch = sample_pcs[sample_b_start] + cd_lst, emd_lst = _pairwise_EMD_CD_sub(metric, + sample_batch, ref_pcs, N_ref, batch_size, + accelerated_cd, verbose, require_grad) + all_cd.append(cd_lst) + all_emd.append(emd_lst) + exp_timer.toc() + iter_id += 1 + + all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref + all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref + + return all_cd, all_emd +# def _pairwise_EMD_CD_(sample_pcs, +# ref_pcs, +# batch_size, +# require_grad=True, +# accelerated_cd=True, +# verbose=True, +# bs=50): +## N_sample = sample_pcs.shape[0] +## N_ref = ref_pcs.shape[0] +# N_sample = 10 +# all_cd = torch.zeros(N_sample, N_ref) #[] +# all_emd = torch.zeros(N_sample, N_ref) +# N_sample = 50 +## all_cd, all_emd = [], [] +## iterator = range(0, N_sample, bs) +# if verbose: +## import tqdm +## iterator = tqdm.tqdm(iterator) +# for i, sample_b_start in enumerate(iterator): +# if bs == 1: +## sample_batch = sample_pcs[sample_b_start] +# cd_lst, emd_lst = _pairwise_EMD_CD_sub( +## sample_batch, ref_pcs, N_ref, batch_size, +# accelerated_cd, verbose, require_grad) +# elif bs > 1: +## sample_b_end = min(N_sample, sample_b_start + bs) +# cd_lst, emd_lst = _pairwise_EMD_CD_(sample_pcs[sample_b_start:sample_b_end], +## ref_pcs, batch_size, +# require_grad=require_grad, +# accelerated_cd=accelerated_cd, +# verbose=verbose, +# bs=1) +## +# all_cd[i:i+1] = cd_lst.cpu() +# all_emd[i:i+1] = 
emd_lst.cpu() +# all_cd.append(cd_lst) +# all_emd.append(emd_lst) +## +# if (len(all_cd)+1) % 36 == 0: +# torch.cuda.empty_cache() +## +# all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref +# all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref +# return all_cd, all_emd +# torch.cuda.empty_cache() +# return all_cd.cpu(), all_emd.cpu() + + +# Adapted from https://github.com/xuqiantong/ +# GAN-Metrics/blob/master/framework/metric.py +def knn(Mxx, Mxy, Myy, k, sqrt=False): + n0 = Mxx.size(0) + n1 = Myy.size(0) + #logger.info('n0={}, n1={}: ', Mxx.shape, Myy.shape) + #logger.info('Mxx={}, Myy={}, Mxy{}', Mxx, Myy, Mxy) + label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx) + #logger.info('label: {}', label.shape) + + M = torch.cat( + [torch.cat((Mxx, Mxy), 1), + torch.cat((Mxy.transpose(0, 1), Myy), 1)], 0) + if sqrt: + M = M.abs().sqrt() + INFINITY = float('inf') + val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk( + k, 0, False) + + count = torch.zeros(n0 + n1).to(Mxx) + for i in range(0, k): + count = count + label.index_select(0, idx[i]) + pred = torch.ge(count, + (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float() + #logger.info('ored: {}, label: {}', pred, label) + s = { + 'tp': (pred * label).sum(), # pred is 1, label is 1 + 'fp': (pred * (1 - label)).sum(), # pred is 1, label is 0 + 'fn': ((1 - pred) * label).sum(), # pred is 0, label is 1 + 'tn': ((1 - pred) * (1 - label)).sum(), # pred is 0, label is 0 + } + #logger.info( 'label: {} | shape: {} ', label.sum(), label.shape) + #logger.info('s={}', s) + + s.update({ + 'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10), + 'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10), + 'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10), + 'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10), + 'acc': torch.eq(label, pred).float().mean(), + }) + return s + + +def lgan_mmd_cov(all_dist): + N_sample, N_ref = all_dist.size(0), all_dist.size(1) + min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) + min_val, _ = torch.min(all_dist, dim=0) + mmd = min_val.mean() + mmd_smp = min_val_fromsmp.mean() + cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) + cov = torch.tensor(cov).to(all_dist) + return { + 'lgan_mmd': mmd, + 'lgan_cov': cov, + 'lgan_mmd_smp': mmd_smp, + } + + +def compute_all_metrics(sample_pcs, ref_pcs, batch_size, + verbose=True, accelerated_cd=False, metric1='CD', + metric2='EMD', **print_kwargs): + results = {} + ## metric2 = 'EMD' + ## metric1 = 'CD' + if verbose: + logger.info("Pairwise EMD CD") + batch_size = ref_pcs.shape[0] // 2 if ref_pcs.shape[0] != batch_size else batch_size + v1 = False + v2 = True if verbose else False + # --- eval CD results --- # + metric = metric1 # 'CD' + if verbose: + logger.info('eval metric: {}; batch-size={}, device: {}, {}', + metric, batch_size, ref_pcs.device, sample_pcs.device) + # batch_size = 100 + # v1 = True + M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(metric, ref_pcs, + sample_pcs, + batch_size, + accelerated_cd=accelerated_cd, + require_grad=False, verbose=v1) + M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(metric, ref_pcs, + sample_pcs, + batch_size, + accelerated_cd=accelerated_cd, + require_grad=False, verbose=v1) + + res_cd = lgan_mmd_cov(M_rs_cd.t()) + results.update({'%s-%s' % (k, metric): v.item() + for k, v in res_cd.items()}) + # logger.info('results: {}', results) + if verbose: + print_results(results, **print_kwargs) + M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(metric, ref_pcs, + ref_pcs, + batch_size, + accelerated_cd=accelerated_cd, + require_grad=False, verbose=v1) 
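+    # (editor's note) The 1-NN two-sample test below needs all three pairwise
+    # blocks: ref-vs-ref (M_rr above), sample-vs-sample (M_ss next) and the
+    # ref-vs-sample block (M_rs); a 1-NN accuracy near 50% means the generated
+    # set is statistically indistinguishable from the reference set.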
+ M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(metric, sample_pcs, + sample_pcs, + batch_size, + accelerated_cd=accelerated_cd, + require_grad=False, verbose=v1) + # 1-NN results + one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) + results.update( + {"1-NN-%s-%s" % (metric, k): v.item() + for k, v in one_nn_cd_res.items() if 'acc' in k}) + # logger.info('results: {}', results) + if verbose: + print_results(results, **print_kwargs) + #logger.info('early exit') + # exit() + # --- eval EMD results --- # + metric = metric2 # 'EMD' + if metric is not None: + if verbose: + logger.info('eval metric: {}', metric) + ## batch_size = min(batch_size, 31) + M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(metric, ref_pcs, + sample_pcs, + batch_size, + accelerated_cd=accelerated_cd, + require_grad=False, verbose=v2) + M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(metric, ref_pcs, + sample_pcs, + batch_size, + accelerated_cd=accelerated_cd, + require_grad=False, verbose=v2) + + res_cd = lgan_mmd_cov(M_rs_cd.t()) + results.update({'%s-%s' % (k, metric): v.item() + for k, v in res_cd.items()}) + if verbose: + print_results(results, **print_kwargs) + # logger.info('results: {}', results) + M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(metric, ref_pcs, + ref_pcs, + batch_size, + accelerated_cd=accelerated_cd, + require_grad=False, verbose=v2) + M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(metric, sample_pcs, + sample_pcs, + batch_size, + accelerated_cd=accelerated_cd, + require_grad=False, verbose=v2) + # 1-NN results + one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) + results.update( + {"1-NN-%s-%s" % (metric, k): v.item() + for k, v in one_nn_cd_res.items() if 'acc' in k}) + if verbose: + print_results(results, **print_kwargs) + # logger.info('results: {}', results) + + return results + + +####################################################### +# JSD : from https://github.com/optas/latent_3d_points +####################################################### +def unit_cube_grid_point_cloud(resolution, clip_sphere=False): + """Returns the center coordinates of each cell of a 3D grid with + resolution^3 cells, that is placed in the unit-cube. If clip_sphere it True + it drops the "corner" cells that lie outside the unit-sphere. + """ + grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) + spacing = 1.0 / float(resolution - 1) + for i in range(resolution): + for j in range(resolution): + for k in range(resolution): + grid[i, j, k, 0] = i * spacing - 0.5 + grid[i, j, k, 1] = j * spacing - 0.5 + grid[i, j, k, 2] = k * spacing - 0.5 + + if clip_sphere: + grid = grid.reshape(-1, 3) + grid = grid[norm(grid, axis=1) <= 0.5] + + return grid, spacing + + +def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28): + """Computes the JSD between two sets of point-clouds, + as introduced in the paper + ```Learning Representations And Generative Models For 3D Point Clouds```. + Args: + sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points. + ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points. + resolution: (int) grid-resolution. Affects granularity of measurements. 
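+        Returns:
+            (float) the estimated Jensen-Shannon divergence between the sets.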
+ """ + in_unit_sphere = True + sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, + in_unit_sphere)[1] + ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, + in_unit_sphere)[1] + return jensen_shannon_divergence(sample_grid_var, ref_grid_var) + + +def entropy_of_occupancy_grid(pclouds, + grid_resolution, + in_sphere=False, + verbose=False): + """Given a collection of point-clouds, estimate the entropy of + the random variables corresponding to occupancy-grid activation patterns. + Inputs: + pclouds: (numpy array) #point-clouds x points per point-cloud x 3 + grid_resolution (int) size of occupancy grid that will be used. + """ + epsilon = 10e-4 + bound = 0.5 + epsilon + if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: + if verbose: + warnings.warn('Point-clouds are not in unit cube.') + + if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound: + if verbose: + warnings.warn('Point-clouds are not in unit sphere.') + + grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, + in_sphere) + grid_coordinates = grid_coordinates.reshape(-1, 3) + grid_counters = np.zeros(len(grid_coordinates)) + grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) + nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates) + + for pc in pclouds: + _, indices = nn.kneighbors(pc) + indices = np.squeeze(indices) + for i in indices: + grid_counters[i] += 1 + indices = np.unique(indices) + for i in indices: + grid_bernoulli_rvars[i] += 1 + + acc_entropy = 0.0 + n = float(len(pclouds)) + for g in grid_bernoulli_rvars: + if g > 0: + p = float(g) / n + acc_entropy += entropy([p, 1.0 - p]) + + return acc_entropy / len(grid_counters), grid_counters + + +def jensen_shannon_divergence(P, Q): + if np.any(P < 0) or np.any(Q < 0): + raise ValueError('Negative values.') + if len(P) != len(Q): + raise ValueError('Non equal size.') + + P_ = P / np.sum(P) # Ensure probabilities. + Q_ = Q / np.sum(Q) + + e1 = entropy(P_, base=2) + e2 = entropy(Q_, base=2) + e_sum = entropy((P_ + Q_) / 2.0, base=2) + res = e_sum - ((e1 + e2) / 2.0) + + res2 = _jsdiv(P_, Q_) + + if not np.allclose(res, res2, atol=10e-5, rtol=0): + warnings.warn('Numerical values of two JSD methods don\'t agree.') + + return res + + +def _jsdiv(P, Q): + """another way of computing JSD""" + def _kldiv(A, B): + a = A.copy() + b = B.copy() + idx = np.logical_and(a > 0, b > 0) + a = a[idx] + b = b[idx] + return np.sum([v for v in a * np.log2(a / b)]) + + P_ = P / np.sum(P) + Q_ = Q / np.sum(Q) + + M = 0.5 * (P_ + Q_) + + return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M)) diff --git a/utils/exp_helper.py b/utils/exp_helper.py new file mode 100644 index 0000000..58320a6 --- /dev/null +++ b/utils/exp_helper.py @@ -0,0 +1,122 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
+import time
+import os
+import numpy as np
+from loguru import logger
+from math import isnan
+from calmsize import size as calmsize
+
+
+def parse_cfg_str(cfg_str):
+    """ parse a config string into a flat [k1, v1, k2, v2, ...] list
+    string format: k1=v1-k2=v2-k3=v3 (entries separated by '-')
+    """
+    cfg_list = cfg_str.split('-')
+    cfg_list = [c for c in cfg_list if len(c) > 0]
+    cfg_expand_list = []
+    for c in cfg_list:
+        k, v = c.split('=')
+        cfg_expand_list.append(k)
+        cfg_expand_list.append(v)
+    return cfg_expand_list
+
+    ##cfg_dict = {}
+    # if cfg_str == '':
+    #     return cfg_dict
+    ##cfg_str_list = cfg_str.split(',')
+    # for p in cfg_str_list:
+    ##    kvs = p.split('=')
+    ##    assert(len(kvs) == 2), f'wrong format, expect k1=v1 for {p}'
+    ##    k, v = kvs
+    ##    cfg_dict[k] = v
+    # return cfg_dict
+
+
+def readable_size(num_bytes: int) -> str:
+    return '' if isnan(num_bytes) else '{:.1f}'.format(calmsize(num_bytes))
+
+
+class ExpTimer(object):
+    def __init__(self, num_epoch, start_epoch=0):
+        self.cur_epoch = start_epoch
+        self.num_epoch = num_epoch
+        self.time_list = []
+
+    def tic(self):
+        self.last_tic = time.time()
+
+    def toc(self):
+        self.time_list.append(time.time() - self.last_tic)
+        self.cur_epoch += 1
+
+    def hours_left(self):
+        if len(self.time_list) == 0:
+            return 0
+        num_epoch_left = self.num_epoch - self.cur_epoch
+        mean_epoch_time = np.array(self.time_list).mean()
+        hours_left = (mean_epoch_time * num_epoch_left) / 3600.0  # hours
+        return hours_left
+
+    def print(self):
+        logger.info('est: {:.1f}h', self.hours_left())
+
+
+def format_e(n):
+    if n == 0:
+        return '0'
+    a = '%E' % n
+    return a.split('E')[0].rstrip('0').rstrip('.') + 'E' + a.split('E-0')[1]
+
+
+def get_evalname(config):
+    # generate tag for the generated samples
+    tag = ''
+    if config.ddpm.model_var_type != 'fixedlarge':
+        tag += config.ddpm.model_var_type
+
+    if not config.ddpm.ema:
+        tag += 'noema'
+    tag += f"s{config.trainer.seed}"
+    if config.data.te_max_sample_points != 2048:
+        tag += 'N%d' % config.data.te_max_sample_points
+    if config.eval_ddim_step > 0:
+        tag += 'ddim%d_%s%.1f' % (
+            config.eval_ddim_step,
+            config.sde.ddim_skip_type,
+            config.sde.ddim_kappa)
+    githash = os.popen('git rev-parse HEAD').read().strip()[:5]
+    logger.info('git hash: {}', githash)
+    tag += f"H{githash}"
+    return tag
+
+
+def get_expname(config):
+    if config.exp_name == '' or config.exp_name == 'none':
+        cate = config.data.cates if type(
+            config.data.cates) is str else config.data.cates[0]
+        cfg_file_name = ''
+        if config.data.type == 'datasets.neuralspline_datasets':
+            cfg_file_name += 'ns'
+        cfg_file_name += '%s/' % cate
+        if len(config.hash):
+            cfg_file_name += '%s_' % config.hash
+        cfg_file_name += f"{config.trainer.type.split('.')[-1].split('_')[0]}_"
+        if len(config.cmt):
+            cfg_file_name += config.cmt + '_'
+
+        cfg_file_name += 'B%d' % config.data.batch_size
+
+        if config.data.tr_max_sample_points != 2048:
+            cfg_file_name += 'N%d' % config.data.tr_max_sample_points
+        run_time = time.strftime('%m%d')
+        cfg_file_name = run_time + '/' + cfg_file_name
+
+    else:
+        cfg_file_name = config.exp_name
+    return cfg_file_name
diff --git a/utils/io_helper.py b/utils/io_helper.py
new file mode 100644
index 0000000..c8d6a89
--- /dev/null
+++ b/utils/io_helper.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+import hashlib
+import json
+
+def hash_str(file_name):
+    BUF_SIZE = 65536  # let's read stuff in 64kb chunks!
+    md5 = hashlib.md5()
+    data = file_name
+    md5.update(data.encode())
+    hash_str = md5.hexdigest()[:6]
+    return hash_str
diff --git a/utils/model_helper.py b/utils/model_helper.py
new file mode 100644
index 0000000..a692542
--- /dev/null
+++ b/utils/model_helper.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist  # needed by average_gradients() below
+from loguru import logger
+from torch.autograd import grad
+import importlib
+from utils.evaluation_metrics_fast import distChamferCUDA, emd_approx, distChamferCUDA_l1
+
+
+def loss_fn(predv, targetv, loss_type, point_dim, batch_size, loss_weight_emd=0.02,
+            loss_weight_cdnorm=1,
+            return_dict=False):
+    B = batch_size
+    output = {}
+
+    if loss_type == 'dcd':
+        from evaluation.dist_aware_cd import calc_dcd
+        res = calc_dcd(predv, targetv)
+        loss = res[0]
+        output['print/rec_dcd'] = loss
+
+    elif loss_type == 'cd1_sum_emd':  # use l1 loss in chamfer distance, take the sum
+        dl, dr = distChamferCUDA_l1(predv, targetv, point_dim)
+        loss = dl + dr  # .view(B,-1).sum(-1) + dr.view(B,-1).sum(-1)
+        output['print/rec_cd1_sum'] = loss
+        emd = emd_approx(predv, targetv)
+        emd = emd.view(B, -1)*predv.view(B, -1).shape[1]
+        output['print/rec_emd'] = emd
+        loss = loss + emd
+
+    elif loss_type == 'cd1_sum':  # use l1 loss in chamfer distance, take the sum
+        dl, dr = distChamferCUDA_l1(predv, targetv, point_dim)
+        loss = dl + dr  # .view(B,-1).sum(-1) + dr.view(B,-1).sum(-1)
+        output['print/rec_cd1_sum'] = loss
+
+    # use l2 loss in chamfer distance; take the sum over the N points, but the mean over the point dim (3)
+    elif loss_type == 'cd_sum':
+        dl, dr = distChamferCUDA(predv, targetv)
+        loss = dl.view(B, -1).sum(-1) + dr.view(B, -1).sum(-1)
+        output['print/rec_cd1_sum'] = loss
+
+    elif loss_type == 'chamfer':
+        dl, dr = distChamferCUDA(predv, targetv)
+        loss = dl.view(B, -1).mean(-1) + dr.view(B, -1).mean(-1)
+        output['print/rec_cd'] = loss
+
+    elif loss_type == 'mse_sum':
+        loss = F.mse_loss(
+            predv.contiguous().view(-1, point_dim), targetv.view(-1, point_dim),
+            reduction='sum')
+        output['print/rec_mse'] = loss
+
+    elif loss_type == 'l1_sum':
+        loss = F.l1_loss(
+            predv.contiguous().view(-1, point_dim), targetv.view(-1, point_dim),
+            reduction='sum')
+        output['print/rec_l1'] = loss
+
+    elif loss_type == 'l1_cd':
+        loss = F.l1_loss(
+            predv.contiguous().view(-1, point_dim), targetv.view(-1, point_dim),
+            reduction='sum')
+        output['print/rec_l1'] = loss
+        dl, dr = distChamferCUDA(predv, targetv)
+        cd_loss = dl.view(B, -1).sum(-1) + dr.view(B, -1).sum(-1)
+        output['print/rec_cd'] = cd_loss
+        loss = loss + cd_loss
+
+    elif loss_type == 'mse':
+        loss = F.mse_loss(
+            predv.contiguous().view(-1, point_dim), targetv.view(-1, point_dim),
+            reduction='mean')
+        output['print/rec_mse'] = loss
+
+    elif loss_type == 'emd':
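+        # (editor's note) Pure-EMD objective: emd_approx() wraps the
+        # third_party/PyTorchEMD CUDA kernel and returns one transport cost per
+        # cloud; the (B, 1) reshape keeps reductions consistent with the
+        # chamfer branches above.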
emd = emd_approx(predv, targetv) + # dl.view(B,-1).mean(-1) + dr.view(B,-1).mean(-1) + loss = emd.view(B, -1) + output['print/rec_emd'] = loss + + elif loss_type == 'chamfer_emd': + dl, dr = distChamferCUDA(predv, targetv) + cd = dl.view(B, -1).mean(-1) + dr.view(B, -1).mean(-1) + cd = cd.view(B, -1) + emd = emd_approx(predv, targetv).view(B, -1) + loss = cd + emd * loss_weight_emd # balance the scale of two loss + output['print/rec_emd'] = emd.mean() + output['print/rec_weight_emd'] = loss_weight_emd + output['print/rec_cd'] = cd.mean() + + else: + raise ValueError(loss_type) + if return_dict: + return loss, output + return loss + + +def import_model(model_str): + logger.info('import: {}', model_str) + p, m = model_str.rsplit('.', 1) + mod = importlib.import_module(p) + Model = getattr(mod, m) + return Model + ## self.encoder = Model(zdim=latent_dim, input_dim=args.ddpm.input_dim, args=args) + + +class DataParallelPassthrough(torch.nn.parallel.DistributedDataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + + +def average_gradients(model, rank=-1): + size = float(dist.get_world_size()) + for name, param in model.named_parameters(): + if not param.requires_grad or param.grad is None: + continue + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM, async_op=True) + param.grad.data /= size + torch.cuda.synchronize() + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters()) # if p.requires_grad) + + +def get_device(model): + param = next(model.parameters()) + return param.device diff --git a/utils/sr_utils.py b/utils/sr_utils.py new file mode 100644 index 0000000..4ecfb08 --- /dev/null +++ b/utils/sr_utils.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
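+# (editor's note) Hypothetical usage sketch -- not part of the original patch --
+# for import_model() / count_parameters() defined in utils/model_helper.py
+# above; any importable "module.Class" dotted path works:
+#
+#   from utils.model_helper import import_model, count_parameters
+#   Model = import_model('torch.nn.Linear')   # same as: from torch.nn import Linear
+#   print(count_parameters(Model(3, 4)))      # 16 = 3*4 weights + 4 biases
+#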
+""" copied and modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/util/sr_utils.py """ +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger + + +@torch.jit.script +def fused_abs_max_add(weight: torch.Tensor, loss: torch.Tensor) -> torch.Tensor: + loss += torch.max(torch.abs(weight)) + return loss + + +class SpectralNormCalculator: + def __init__(self, num_power_iter=4, custom_conv=False): + self.num_power_iter = num_power_iter + # increase the number of iterations for the first time + self.num_power_iter_init = 10 * num_power_iter + self.all_conv_layers = [] + # left/right singular vectors used for SR + self.sr_u = {} + self.sr_v = {} + self.all_bn_layers = [] + self.custom_conv = custom_conv + + def add_conv_layers(self, model): + for n, layer in model.named_modules(): + if self.custom_conv: + # add our customized conv layers + if isinstance(layer, Conv2D) or isinstance(layer, ARConv2d): + self.all_conv_layers.append(layer) + else: + if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Conv3d) or \ + isinstance(layer, nn.Conv1d) or isinstance(layer, nn.Linear): # add pytorch conv layers + self.all_conv_layers.append(layer) + + def add_bn_layers(self, model): + for n, layer in model.named_modules(): + if isinstance(layer, nn.BatchNorm2d) or isinstance(layer, nn.SyncBatchNorm) or \ + isinstance(layer, nn.GroupNorm): + self.all_bn_layers.append(layer) + + def spectral_norm_parallel(self): + """ This method computes spectral normalization for all conv layers in parallel. This method should be called + after calling the forward method of all the conv layers in each iteration. """ + + weights = {} # a dictionary indexed by the shape of weights + for l in self.all_conv_layers: + weight = l.weight_normalized if self.custom_conv else l.weight + if not isinstance(l, nn.Linear): + weight_mat = weight.view(weight.size(0), -1) + else: + weight_mat = weight + ## logger.info('mat weight: {} | weight: {}', weight_mat.shape, weight.shape) + + if weight_mat.shape not in weights: + weights[weight_mat.shape] = [] + + weights[weight_mat.shape].append(weight_mat) + + loss = 0 + for i in weights: + weights[i] = torch.stack(weights[i], dim=0) + with torch.no_grad(): + num_iter = self.num_power_iter + if i not in self.sr_u: + num_w, row, col = weights[i].shape + self.sr_u[i] = F.normalize(torch.ones( + num_w, row).normal_(0, 1).cuda(), dim=1, eps=1e-3) + self.sr_v[i] = F.normalize(torch.ones( + num_w, col).normal_(0, 1).cuda(), dim=1, eps=1e-3) + num_iter = self.num_power_iter_init + + for j in range(num_iter): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. 
+                    self.sr_v[i] = F.normalize(torch.matmul(self.sr_u[i].unsqueeze(1), weights[i]).squeeze(1),
+                                               dim=1, eps=1e-3)  # bx1xr * bxrxc --> bx1xc --> bxc
+                    self.sr_u[i] = F.normalize(torch.matmul(weights[i], self.sr_v[i].unsqueeze(2)).squeeze(2),
+                                               dim=1, eps=1e-3)  # bxrxc * bxcx1 --> bxrx1 --> bxr
+
+            sigma = torch.matmul(self.sr_u[i].unsqueeze(
+                1), torch.matmul(weights[i], self.sr_v[i].unsqueeze(2)))
+            loss += torch.sum(sigma)
+        return loss
+
+    def batchnorm_loss(self):
+        loss = torch.zeros(size=()).cuda()
+        for l in self.all_bn_layers:
+            if l.affine:
+                loss = fused_abs_max_add(l.weight, loss)
+
+        return loss
+
+    def state_dict(self):
+        return {
+            'sr_v': self.sr_v,
+            'sr_u': self.sr_u
+        }
+
+    def load_state_dict(self, state_dict, device):
+        # map the tensor to the device id of self.sr_v
+        for s in state_dict['sr_v']:
+            self.sr_v[s] = state_dict['sr_v'][s].to(device)
+
+        for s in state_dict['sr_u']:
+            self.sr_u[s] = state_dict['sr_u'][s].to(device)
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..ff58436
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,1532 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+"""copied and modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/util/utils.py"""
+from loguru import logger
+from comet_ml import Experiment, ExistingExperiment
+import wandb as WB
+import json  # used by common_init below to read .comet_api / .wandb_api
+import os
+import math
+import shutil
+import time
+import sys
+import types
+from PIL import Image
+import torch
+import torch.nn as nn
+import numpy as np
+from torch import optim
+import torch.distributed as dist
+from torch.cuda.amp import autocast, GradScaler
+USE_COMET = int(os.environ.get('USE_COMET', 1))
+USE_TFB = int(os.environ.get('USE_TFB', 0))
+USE_WB = int(os.environ.get('USE_WB', 0))
+print(f'utils/utils.py: USE_COMET={USE_COMET}, USE_WB={USE_WB}')
+
+class PixelNormal(object):
+    def __init__(self, param, fixed_log_scales=None):
+        size = param.size()
+        C = size[1]
+        if fixed_log_scales is None:
+            self.num_c = C // 2
+            # B, 1 or 3, H, W
+            self.means = param[:, :self.num_c, :, :]
+            self.log_scales = torch.clamp(
+                param[:, self.num_c:, :, :], min=-7.0)  # B, 1 or 3, H, W
+            raise NotImplementedError
+        else:
+            self.num_c = C
+            # B, 1 or 3, H, W
+            self.means = param
+            # B, 1 or 3, H, W
+            self.log_scales = view4D(fixed_log_scales, size)
+
+    def get_params(self):
+        return self.means, self.log_scales, self.num_c
+
+    def log_prob(self, samples):
+        B, C, H, W = samples.size()
+        assert C == self.num_c
+
+        log_probs = -0.5 * torch.square(self.means - samples) * torch.exp(-2.0 *
+                                                                          self.log_scales) - self.log_scales - 0.9189385332  # -0.5*log(2*pi)
+        return log_probs
+
+    def sample(self, t=1.):
+        z, rho = sample_normal_jit(
+            self.means, torch.exp(self.log_scales)*t)  # B, 3, H, W
+        return z
+
+    def log_prob_discrete(self, samples):
+        """
+        Calculates discrete pixel probabilities.
+        """
+        # samples should be in [-1, 1] already
+        B, C, H, W = samples.size()
+        assert C == self.num_c
+
+        centered = samples - self.means
+        inv_stdv = torch.exp(- self.log_scales)
+        plus_in = inv_stdv * (centered + 1. / 255.)
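+        # discretized Gaussian likelihood: each 8-bit pixel bin has width
+        # 2/255 in [-1, 1], so its probability is CDF(x + 1/255) - CDF(x - 1/255)
+        # under Normal(mean, scale); the edge bins below fall back to the tails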
+        cdf_plus = torch.distributions.Normal(0, 1).cdf(plus_in)
+        min_in = inv_stdv * (centered - 1. / 255.)
+        cdf_min = torch.distributions.Normal(0, 1).cdf(min_in)
+        log_cdf_plus = torch.log(torch.clamp(cdf_plus, min=1e-12))
+        log_one_minus_cdf_min = torch.log(torch.clamp(1. - cdf_min, min=1e-12))
+        cdf_delta = cdf_plus - cdf_min
+        log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.999, log_one_minus_cdf_min,
+                                                                            torch.log(torch.clamp(cdf_delta, min=1e-12))))
+
+        assert log_probs.size() == samples.size()
+        return log_probs
+
+    def mean(self):
+        return self.means
+
+
+class DummyGradScalar(object):
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def scale(self, input):
+        return input
+
+    def update(self):
+        pass
+
+    def state_dict(self):
+        return {}
+
+    def load_state_dict(self, x):
+        pass
+
+    def step(self, opt):
+        opt.step()
+
+    def unscale_(self, x):
+        return x
+
+
+def get_opt(params, cfgopt, use_ema, other_cfg=None):
+    if cfgopt.type == 'adam':
+        optimizer = optim.Adam(params,
+                               lr=float(cfgopt.lr),
+                               betas=(cfgopt.beta1, cfgopt.beta2),
+                               weight_decay=cfgopt.weight_decay)
+    elif cfgopt.type == 'sgd':
+        optimizer = torch.optim.SGD(params,
+                                    lr=float(cfgopt.lr),
+                                    momentum=cfgopt.momentum)
+    elif cfgopt.type == 'adamax':
+        from utils.adamax import Adamax
+        logger.info('[Optimizer] Adamax, lr={}, weight_decay={}, eps={}',
+                    cfgopt.lr, cfgopt.weight_decay, 1e-4)
+        optimizer = Adamax(params, float(cfgopt.lr),
+                           weight_decay=cfgopt.weight_decay, eps=1e-4)
+
+    else:
+        assert 0, "Optimizer type should be 'adam', 'adamax', or 'sgd'"
+    if use_ema:
+        logger.info('use_ema')
+        ema_decay = 0.9999
+        from .ema import EMA
+        optimizer = EMA(optimizer, ema_decay=ema_decay)
+    scheduler = optim.lr_scheduler.LambdaLR(
+        optimizer, lr_lambda=lambda x: 1.0)  # constant lr
+    scheduler_type = getattr(cfgopt, "scheduler", None)
+    if scheduler_type is not None and len(scheduler_type) > 0:
+        logger.info('get scheduler_type: {}', scheduler_type)
+        if scheduler_type == 'exponential':
+            decay = float(getattr(cfgopt, "step_decay", 0.1))
+            scheduler = optim.lr_scheduler.ExponentialLR(optimizer, decay)
+        elif scheduler_type == 'step':
+            step_size = int(getattr(cfgopt, "step_epoch", 500))
+            decay = float(getattr(cfgopt, "step_decay", 0.1))
+            scheduler = optim.lr_scheduler.StepLR(optimizer,
+                                                  step_size=step_size,
+                                                  gamma=decay)
+        elif scheduler_type == 'linear':  # use default setting from shapeLatent
+            start_epoch = int(getattr(cfgopt, 'sched_start_epoch', 200*1e3))
+            end_epoch = int(getattr(cfgopt, 'sched_end_epoch', 400*1e3))
+            end_lr = float(getattr(cfgopt, 'end_lr', 1e-4))
+            start_lr = cfgopt.lr
+
+            def lambda_rule(epoch):
+                if epoch <= start_epoch:
+                    return 1.0
+                elif epoch <= end_epoch:
+                    total = end_epoch - start_epoch
+                    delta = epoch - start_epoch
+                    frac = delta / total
+                    return (1 - frac) * 1.0 + frac * (end_lr / start_lr)
+                else:
+                    return end_lr / start_lr
+            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
+                                                    lr_lambda=lambda_rule)
+
+        elif scheduler_type == 'lambda':  # linear':
+            step_size = int(getattr(cfgopt, "step_epoch", 2000))
+            final_ratio = float(getattr(cfgopt, "final_ratio", 0.01))
+            start_ratio = float(getattr(cfgopt, "start_ratio", 0.5))
+            duration_ratio = float(getattr(cfgopt, "duration_ratio", 0.45))
+
+            def lambda_rule(ep):
+                lr_l = 1.0 - min(
+                    1,
+                    max(0, ep - start_ratio * step_size) /
+                    float(duration_ratio * step_size)) * (1 - final_ratio)
+                return lr_l
+
+            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
+                                                    lr_lambda=lambda_rule)
+
+        elif scheduler_type ==
'cosine_anneal_nocycle': + ## logger.info('scheduler_type: {}', scheduler_type) + assert(other_cfg is not None) + final_lr_ratio = float(getattr(cfgopt, "final_lr_ratio", 0.01)) + eta_min = float(cfgopt.lr) * final_lr_ratio + eta_max = float(cfgopt.lr) + + total_epoch = int(other_cfg.trainer.epochs) + ##getattr(cfgopt, "step_epoch", 2000) + start_ratio = float(getattr(cfgopt, "start_ratio", 0.6)) + T_max = total_epoch * (1 - start_ratio) + + def lambda_rule(ep): + curr_ep = max(0., ep - start_ratio * total_epoch) + lr = eta_min + 0.5 * (eta_max - eta_min) * ( + 1 + np.cos(np.pi * curr_ep / T_max)) + lr_l = lr / eta_max + return lr_l + + scheduler = optim.lr_scheduler.LambdaLR(optimizer, + lr_lambda=lambda_rule) + + else: + assert 0, "args.schedulers should be either 'exponential' or 'linear' or 'step'" + return optimizer, scheduler + + +class AvgrageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +class ExpMovingAvgrageMeter(object): + + def __init__(self, momentum=0.9): + self.momentum = momentum + self.reset() + + def reset(self): + self.avg = 0 + + def update(self, val): + self.avg = (1. - self.momentum) * self.avg + self.momentum * val + + +class DummyDDP(nn.Module): + def __init__(self, model): + super(DummyDDP, self).__init__() + self.module = model + + def forward(self, *input, **kwargs): + return self.module(*input, **kwargs) + + +def count_parameters_in_M(model): + return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6 + + +def save_checkpoint(state, is_best, save): + filename = os.path.join(save, 'checkpoint.pth.tar') + torch.save(state, filename) + if is_best: + best_filename = os.path.join(save, 'model_best.pth.tar') + shutil.copyfile(filename, best_filename) + + +def save(model, model_path): + torch.save(model.state_dict(), model_path) + + +def load(model, model_path): + model.load_state_dict(torch.load(model_path)) + + +# def create_exp_dir(path, scripts_to_save=None): +# if not os.path.exists(path): +# os.makedirs(path, exist_ok=True) +# print('Experiment dir : {}'.format(path)) +# +# if scripts_to_save is not None: +# if not os.path.exists(os.path.join(path, 'scripts')): +# os.mkdir(os.path.join(path, 'scripts')) +# for script in scripts_to_save: +# dst_file = os.path.join(path, 'scripts', os.path.basename(script)) +# shutil.copyfile(script, dst_file) +# + +# class Logger(object): +# def __init__(self, rank, save): +# # other libraries may set logging before arriving at this line. +# # by reloading logging, we can get rid of previous configs set by other libraries. 
+#         from importlib import reload
+#         reload(logging)
+#         self.rank = rank
+#         if self.rank == 0:
+#             log_format = '%(asctime)s %(message)s'
+#             logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+#                                 format=log_format, datefmt='%m/%d %I:%M:%S %p')
+#             fh = logging.FileHandler(os.path.join(save, 'log.txt'))
+#             fh.setFormatter(logging.Formatter(log_format))
+#             logging.getLogger().addHandler(fh)
+#         self.start_time = time.time()
+#
+#     def info(self, string, *args):
+#         if self.rank == 0:
+#             elapsed_time = time.time() - self.start_time
+#             elapsed_time = time.strftime(
+#                 '(Elapsed: %H:%M:%S) ', time.gmtime(elapsed_time))
+#             if isinstance(string, str):
+#                 string = elapsed_time + string
+#             else:
+#                 logging.info(elapsed_time)
+#             logging.info(string, *args)
+
+def flatten_dict(dd, separator='_', prefix=''):
+    return {prefix + separator + k if prefix else k: v for kk, vv in dd.items()
+            for k, v in flatten_dict(vv, separator, kk).items()} \
+        if isinstance(dd, dict) else {prefix: dd}
+
+
+class Writer(object):
+    def __init__(self, rank=0, save=None, exp=None, wandb=False):
+        self.rank = rank
+        self.exp = None
+        self.wandb = False
+        self.meter_dict = {}
+        if self.rank == 0:
+            self.exp = exp
+            if USE_TFB and save is not None:
+                logger.info('init TFB: {}', save)
+                from torch.utils.tensorboard import SummaryWriter
+                self.writer = SummaryWriter(log_dir=save, flush_secs=20)
+            else:
+                logger.info('Not init TFB')
+                self.writer = None
+            if self.exp is not None and save is not None:
+                with open(os.path.join(save, 'url.txt'), 'a') as f:
+                    f.write(self.exp.url)
+                    f.write('\n')
+            self.wandb = wandb
+        else:
+            logger.info('rank={}, init writer as a blackhole', rank)
+
+    def set_model_graph(self, *args, **kwargs):
+        if self.rank == 0 and self.exp is not None:
+            self.exp.set_model_graph(*args, **kwargs)
+
+    @property
+    def url(self):
+        if self.exp is not None:
+            return self.exp.url
+        else:
+            return 'none'
+
+    def add_hparams(self, cfg, args):  # **kwargs):
+        if self.exp is not None:
+            self.exp.log_parameters(flatten_dict(cfg))
+            self.exp.log_parameters(flatten_dict(args))
+        if self.wandb:
+            WB.config.update(flatten_dict(cfg))
+            WB.config.update(flatten_dict(args))
+
+    def avg_meter(self, name, value, step=None, epoch=None):
+        if self.rank == 0:
+            if name not in self.meter_dict:
+                self.meter_dict[name] = AvgrageMeter()
+            self.meter_dict[name].update(value)
+
+    def upload_meter(self, step=None, epoch=None):
+        for name, value in self.meter_dict.items():
+            self.add_scalar(name, value.avg, step=step, epoch=epoch)
+        self.meter_dict = {}
+
+    def add_scalar(self, *args, **kwargs):
+        if self.rank == 0 and self.writer is not None:
+            if 'step' in kwargs:
+                self.writer.add_scalar(*args,
+                                       global_step=kwargs['step'])
+            else:
+                self.writer.add_scalar(*args, **kwargs)
+
+        if self.exp is not None:
+            self.exp.log_metric(*args, **kwargs)
+        if self.wandb:
+            name = args[0]
+            v = args[1]
+            WB.log({name: v})
+
+    def log_model(self, name, path):
+        pass
+        # if self.rank == 0 and self.exp is not None:
+        #     self.exp.log_model(name, path)
+
+    def log_other(self, name, value):
+        if self.rank == 0 and self.exp is not None:
+            self.exp.log_other(name, value)
+
+    def watch(self, model):
+        if self.wandb:
+            WB.watch(model)
+
+    def log_points_3d(self, scene_name, points, step=0):
+        if self.rank == 0 and self.exp is not None:
+            self.exp.log_points_3d(scene_name, points, step=step)
+        if self.wandb:
+            WB.log({"point_cloud": WB.Object3D(points)})
+
+    def add_figure(self, *args, **kwargs):
+        if self.rank == 0 and self.writer is not None:
self.writer.add_figure(*args, **kwargs) + + def add_image(self, *args, **kwargs): + if self.rank == 0 and self.writer is not None: + self.writer.add_image(*args, **kwargs) + self.writer.flush() + if self.exp is not None: + name, img, i = args + if isinstance(img, Image.Image): + # logger.debug('log PIL Imgae: {}, {}', name, i) + self.exp.log_image(img, name, step=i) + elif type(img) is str: + # logger.debug('log str image: {}, {}: {}', name, i, img) + self.exp.log_image(img, name, step=i) + elif torch.is_tensor(img): + if img.shape[0] in [3, 4] and len(img.shape) == 3: # 3,H,W + img = img.permute(1, 2, 0).contiguous() # 3,H,W -> H,W,3 + if img.max() < 100: # [0-1] + ndarr = img.mul(255).add_(0.5).clamp_( + 0, 255).to('cpu') # .squeeze() + ndarr = ndarr.numpy().astype(np.uint8) + # .reshape(-1, ndarr.shape[-1])) + im = Image.fromarray(ndarr) + self.exp.log_image(im, name, step=i) + else: + im = img.to('cpu').numpy() + self.exp.log_image(im, name, step=i) + + elif isinstance(img, (np.ndarray, np.generic)): + if img.shape[0] == 3 and len(img.shape) == 3: # 3,H,W + img = img.transpose(1, 2, 0) + self.exp.log_image(img, name, step=i) + if self.wandb and torch.is_tensor(img) and self.rank == 0: + ## print(img.shape, img.max(), img.type()) + WB.log({name: WB.Image(img.numpy())}) + + def add_histogram(self, *args, **kwargs): + if self.rank == 0 and self.writer is not None: + self.writer.add_histogram(*args, **kwargs) + if self.exp is not None: + name, value, step = args + self.exp.log_histogram_3d(value, name, step) + # *args, **kwargs) + + def add_histogram_if(self, write, *args, **kwargs): + if write and False: # Used for debugging. + self.add_histogram(*args, **kwargs) + + def close(self, *args, **kwargs): + if self.rank == 0 and self.writer is not None: + self.writer.close() + + def log_asset(self, *args, **kwargs): + if self.exp is not None: + self.exp.log_asset(*args, **kwargs) + + +def common_init(rank, seed, save_dir, comet_key=''): + # we use different seeds per gpu. But we sync the weights after model initialization. 
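+    # e.g. with seed=0 the effective seeds are 0, 1, 2, ... across ranks, so
+    # noise and data order differ per GPU; the weights are synced separately
+    # after initialization (cf. broadcast_params in this file)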
+ logger.info('[common-init] at rank={}, seed={}', rank, seed) + torch.manual_seed(rank + seed) + np.random.seed(rank + seed) + torch.cuda.manual_seed(rank + seed) + torch.cuda.manual_seed_all(rank + seed) + torch.backends.cudnn.benchmark = True + + # prepare logging and tensorboard summary + #logging = Logger(rank, save_dir) + logging = None + if rank == 0: + if os.path.exists('.comet_api'): + comet_args = json.load(open('.comet_api', 'r')) + exp = Experiment(display_summary_level=0, + disabled=USE_COMET == 0, + **comet_args) + exp.set_name(save_dir.split('exp/')[-1]) + exp.set_cmd_args() + exp.log_code(folder='./models/') + exp.log_code(folder='./trainers/') + exp.log_code(folder='./utils/') + exp.log_code(folder='./datasets/') + else: + exp = None + + if os.path.exists('.wandb_api'): + wb_args = json.load(open('.wandb_api', 'r')) + wb_dir = '../exp/wandb/' if not os.path.exists( + '/workspace/result') else '/workspace/result/wandb/' + if not os.path.exists(wb_dir): + os.makedirs(wb_dir) + WB.init( + project=wb_args['project'], + entity=wb_args['entity'], + name=save_dir.split('exp/')[-1], + dir=wb_dir + ) + wandb = True + else: + wandb = False + else: + exp = None + wandb = False + writer = Writer(rank, save_dir, exp, wandb) + logger.info('[common-init] DONE') + + return logging, writer + + +def reduce_tensor(tensor, world_size): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= world_size + return rt + + +def get_stride_for_cell_type(cell_type): + if cell_type.startswith('normal') or cell_type.startswith('combiner'): + stride = 1 + elif cell_type.startswith('down'): + stride = 2 + elif cell_type.startswith('up'): + stride = -1 + else: + raise NotImplementedError(cell_type) + + return stride + + +def get_cout(cin, stride): + if stride == 1: + cout = cin + elif stride == -1: + cout = cin // 2 + elif stride == 2: + cout = 2 * cin + + return cout + + +def kl_balancer_coeff(num_scales, groups_per_scale, fun='square'): + if fun == 'equal': + coeff = torch.cat([torch.ones(groups_per_scale[num_scales - i - 1]) + for i in range(num_scales)], dim=0).cuda() + elif fun == 'linear': + coeff = torch.cat([(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], + dim=0).cuda() + elif fun == 'sqrt': + coeff = torch.cat( + [np.sqrt(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) + for i in range(num_scales)], + dim=0).cuda() + elif fun == 'square': + coeff = torch.cat( + [np.square(2 ** i) / groups_per_scale[num_scales - i - 1] * torch.ones(groups_per_scale[num_scales - i - 1]) + for i in range(num_scales)], dim=0).cuda() + else: + raise NotImplementedError + # convert min to 1. 
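+    # e.g. fun='square' with 2 scales of 2 groups each yields
+    # (0.5, 0.5, 2, 2) above, rescaled below to (1, 1, 4, 4): the smallest
+    # per-group coefficient becomes exactly 1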
+ coeff /= torch.min(coeff) + return coeff + + +def kl_per_group(kl_all): + kl_vals = torch.mean(kl_all, dim=0) + kl_coeff_i = torch.abs(kl_all) + kl_coeff_i = torch.mean(kl_coeff_i, dim=0, keepdim=True) + 0.01 + + return kl_coeff_i, kl_vals + + +def rec_balancer(rec_all, rec_coeff=1.0, npoints=None): + # layer depth increase, alpha_i increase, 1/alpha_i decrease; kl_coeff decrease + # the rec with more points should have higher loss + min_points = min(npoints) + coeff = [] + rec_loss = 0 + assert(len(rec_all) == len(npoints)) + for ni, n in enumerate(npoints): + c = rec_coeff*np.sqrt(n/min_points) + rec_loss += rec_all[ni] * c + coeff.append(c) # the smallest points' loss weight is 1 + + return rec_loss, coeff, rec_all + + +def kl_balancer(kl_all, kl_coeff=1.0, kl_balance=False, alpha_i=None): + # layer depth increase, alpha_i increase, 1/alpha_i decrease; kl_coeff decrease + if kl_balance and kl_coeff < 1.0: + alpha_i = alpha_i.unsqueeze(0) + + kl_all = torch.stack(kl_all, dim=1) + kl_coeff_i, kl_vals = kl_per_group(kl_all) + total_kl = torch.sum(kl_coeff_i) + # kl = ( sum * kl / alpha ) + kl_coeff_i = kl_coeff_i / alpha_i * total_kl + kl_coeff_i = kl_coeff_i / torch.mean(kl_coeff_i, dim=1, keepdim=True) + kl = torch.sum(kl_all * kl_coeff_i.detach(), dim=1) + + # for reporting + kl_coeffs = kl_coeff_i.squeeze(0) + else: + kl_all = torch.stack(kl_all, dim=1) + kl_vals = torch.mean(kl_all, dim=0) + kl = torch.sum(kl_all, dim=1) + kl_coeffs = torch.ones(size=(len(kl_vals),)) + + return kl_coeff * kl, kl_coeffs, kl_vals + + +def kl_per_group_vada(all_log_q, all_neg_log_p): + assert(len(all_log_q) == len(all_neg_log_p) + ), f'get len={len(all_log_q)} and {len(all_neg_log_p)}' + + kl_all_list = [] + kl_diag = [] + for log_q, neg_log_p in zip(all_log_q, all_neg_log_p): + kl_diag.append(torch.mean( + torch.sum(neg_log_p + log_q, dim=[2, 3]), dim=0)) + kl_all_list.append(torch.sum(neg_log_p + log_q, + dim=[1, 2, 3])) # sum over D,H,W + + # kl_all = torch.stack(kl_all, dim=1) # batch x num_total_groups + kl_vals = torch.mean(torch.stack(kl_all_list, dim=1), + dim=0) # mean per group + + return kl_all_list, kl_vals, kl_diag + + +def kl_coeff(step, total_step, constant_step, min_kl_coeff, max_kl_coeff): + # return max(min(max_kl_coeff * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff) + return max(min(min_kl_coeff + (max_kl_coeff - min_kl_coeff) * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff) + + +def log_iw(decoder, x, log_q, log_p, crop=False): + recon = reconstruction_loss(decoder, x, crop) + return - recon - log_q + log_p + + +def reconstruction_loss(decoder, x, crop=False): + + recon = decoder.log_p(x) + if crop: + recon = recon[:, :, 2:30, 2:30] + + if isinstance(decoder, DiscMixLogistic): + return - torch.sum(recon, dim=[1, 2]) # summation over RGB is done. + else: + return - torch.sum(recon, dim=[1, 2, 3]) + + +def vae_terms(all_log_q, all_eps): + + # compute kl + kl_all = [] + kl_diag = [] + log_p, log_q = 0., 0. + for log_q_conv, eps in zip(all_log_q, all_eps): + log_p_conv = log_p_standard_normal(eps) + kl_per_var = log_q_conv - log_p_conv + kl_diag.append(torch.mean(torch.sum(kl_per_var, dim=[2, 3]), dim=0)) + kl_all.append(torch.sum(kl_per_var, dim=[1, 2, 3])) + log_q += torch.sum(log_q_conv, dim=[1, 2, 3]) + log_p += torch.sum(log_p_conv, dim=[1, 2, 3]) + return log_q, log_p, kl_all, kl_diag + + +def sum_log_q(all_log_q): + log_q = 0. 
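+    # sum each group's log q over its (C, H, W) dims, then accumulate over
+    # groups: one scalar log q(z|x) per batch element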
+ for log_q_conv in all_log_q: + log_q += torch.sum(log_q_conv, dim=[1, 2, 3]) + + return log_q + + +def cross_entropy_normal(all_eps): + + cross_entropy = 0. + neg_log_p_per_group = [] + for eps in all_eps: + neg_log_p_conv = - log_p_standard_normal(eps) + neg_log_p = torch.sum(neg_log_p_conv, dim=[1, 2, 3]) + cross_entropy += neg_log_p + neg_log_p_per_group.append(neg_log_p_conv) + + return cross_entropy, neg_log_p_per_group + + +def tile_image(batch_image, n, m=None): + if m is None: + m = n + assert n * m == batch_image.size(0) + channels, height, width = batch_image.size( + 1), batch_image.size(2), batch_image.size(3) + batch_image = batch_image.view(n, m, channels, height, width) + batch_image = batch_image.permute(2, 0, 3, 1, 4) # n, height, n, width, c + batch_image = batch_image.contiguous().view(channels, n * height, m * width) + return batch_image + + +def average_gradients_naive(params, is_distributed): + """ Gradient averaging. """ + if is_distributed: + size = float(dist.get_world_size()) + for param in params: + if param.requires_grad: + param.grad.data /= size + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + + +def average_gradients(params, is_distributed): + """ Gradient averaging. """ + if is_distributed: + if isinstance(params, types.GeneratorType): + params = [p for p in params] + + size = float(dist.get_world_size()) + grad_data = [] + grad_size = [] + grad_shapes = [] + # Gather all grad values + for param in params: + if param.requires_grad: + if param.grad is not None: + grad_size.append(param.grad.data.numel()) + grad_shapes.append(list(param.grad.data.shape)) + grad_data.append(param.grad.data.flatten()) + grad_data = torch.cat(grad_data).contiguous() + + # All-reduce grad values + grad_data /= size + dist.all_reduce(grad_data, op=dist.ReduceOp.SUM) + + # Put back the reduce grad values to parameters + base = 0 + i = 0 + for param in params: + if param.requires_grad and param.grad is not None: + param.grad.data = grad_data[base:base + + grad_size[i]].view(grad_shapes[i]) + base += grad_size[i] + i += 1 + + +def average_params(params, is_distributed): + """ parameter averaging. """ + if is_distributed: + size = float(dist.get_world_size()) + for param in params: + param.data /= size + dist.all_reduce(param.data, op=dist.ReduceOp.SUM) + + +def average_tensor(t, is_distributed): + if is_distributed: + size = float(dist.get_world_size()) + dist.all_reduce(t.data, op=dist.ReduceOp.SUM) + t.data /= size + + +def broadcast_params(params, is_distributed): + if is_distributed: + for param in params: + dist.broadcast(param.data, src=0) + + +def num_output(dataset): + if dataset in {'mnist', 'omniglot'}: + return 28 * 28 + elif dataset == 'cifar10': + return 3 * 32 * 32 + elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'): + size = int(dataset.split('_')[-1]) + return 3 * size * size + elif dataset == 'ffhq': + return 3 * 256 * 256 + else: + raise NotImplementedError + + +def get_input_size(dataset): + if dataset in {'mnist', 'omniglot'}: + return 32 + elif dataset == 'cifar10': + return 32 + elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'): + size = int(dataset.split('_')[-1]) + return size + elif dataset == 'ffhq': + return 256 + elif dataset.startswith('shape'): + return 1 # 2048 + else: + raise NotImplementedError + + +def get_bpd_coeff(dataset): + n = num_output(dataset) + return 1. / np.log(2.) 
/ n + + +def get_channel_multiplier(dataset, num_scales): + if dataset in {'cifar10', 'omniglot'}: + mult = (1, 1, 1) + elif dataset in {'celeba_256', 'ffhq', 'lsun_church_256'}: + if num_scales == 3: + mult = (1, 1, 1) # used for prior at 16 + elif num_scales == 4: + mult = (1, 2, 2, 2) # used for prior at 32 + elif num_scales == 5: + mult = (1, 1, 2, 2, 2) # used for prior at 64 + elif dataset == 'mnist': + mult = (1, 1) + else: + mult = (1, 1) + # raise NotImplementedError + + return mult + + +def get_attention_scales(dataset): + if dataset in {'cifar10', 'omniglot'}: + attn = (True, False, False) + elif dataset in {'celeba_256', 'ffhq', 'lsun_church_256'}: + # attn = (False, True, False, False) # used for 32 + attn = (False, False, True, False, False) # used for 64 + elif dataset == 'mnist': + attn = (True, False) + else: + raise NotImplementedError + + return attn + + +def change_bit_length(x, num_bits): + if num_bits != 8: + x = torch.floor(x * 255 / 2 ** (8 - num_bits)) + x /= (2 ** num_bits - 1) + return x + + +def view4D(t, size, inplace=True): + """ + Equal to view(-1, 1, 1, 1).expand(size) + Designed because of this bug: + https://github.com/pytorch/pytorch/pull/48696 + """ + if inplace: + return t.unsqueeze_(-1).unsqueeze_(-1).unsqueeze_(-1).expand(size) + else: + return t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(size) + + +def get_arch_cells(arch_type, use_se): + if arch_type == 'res_mbconv': + arch_cells = dict() + arch_cells['normal_enc'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['down_enc'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['normal_dec'] = { + 'conv_branch': ['mconv_e6k5g0'], 'se': use_se} + arch_cells['up_dec'] = {'conv_branch': ['mconv_e6k5g0'], 'se': use_se} + arch_cells['normal_pre'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['down_pre'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['normal_post'] = { + 'conv_branch': ['mconv_e3k5g0'], 'se': use_se} + arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se} + arch_cells['ar_nn'] = [''] + elif arch_type == 'res_bnswish': + arch_cells = dict() + arch_cells['normal_enc'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['down_enc'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['normal_dec'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['up_dec'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['normal_pre'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['down_pre'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['normal_post'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['up_post'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['ar_nn'] = [''] + elif arch_type == 'res_bnswish2': + arch_cells = dict() + arch_cells['normal_enc'] = { + 'conv_branch': ['res_bnswish_x2'], 'se': use_se} + arch_cells['down_enc'] = { + 'conv_branch': ['res_bnswish_x2'], 'se': use_se} + arch_cells['normal_dec'] = { + 'conv_branch': ['res_bnswish_x2'], 'se': use_se} + arch_cells['up_dec'] = {'conv_branch': [ + 'res_bnswish_x2'], 'se': use_se} + arch_cells['normal_pre'] = { + 'conv_branch': ['res_bnswish_x2'], 'se': use_se} + arch_cells['down_pre'] = { + 'conv_branch': ['res_bnswish_x2'], 'se': use_se} + 
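+        # each entry maps a cell type to its conv-branch op list plus a
+        # squeeze-and-excitation flag ('se'); the same pattern repeats below
+        # for every arch_type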
arch_cells['normal_post'] = { + 'conv_branch': ['res_bnswish_x2'], 'se': use_se} + arch_cells['up_post'] = {'conv_branch': [ + 'res_bnswish_x2'], 'se': use_se} + arch_cells['ar_nn'] = [''] + elif arch_type == 'res_mbconv_attn': + arch_cells = dict() + arch_cells['normal_enc'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish', ], 'se': use_se, 'attn_type': 'attn'} + arch_cells['down_enc'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se, 'attn_type': 'attn'} + arch_cells['normal_dec'] = {'conv_branch': [ + 'mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'} + arch_cells['up_dec'] = {'conv_branch': [ + 'mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'} + arch_cells['normal_pre'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['down_pre'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['normal_post'] = { + 'conv_branch': ['mconv_e3k5g0'], 'se': use_se} + arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se} + arch_cells['ar_nn'] = [''] + elif arch_type == 'res_mbconv_attn_half': + arch_cells = dict() + arch_cells['normal_enc'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['down_enc'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['normal_dec'] = {'conv_branch': [ + 'mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'} + arch_cells['up_dec'] = {'conv_branch': [ + 'mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'} + arch_cells['normal_pre'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['down_pre'] = {'conv_branch': [ + 'res_bnswish', 'res_bnswish'], 'se': use_se} + arch_cells['normal_post'] = { + 'conv_branch': ['mconv_e3k5g0'], 'se': use_se} + arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se} + arch_cells['ar_nn'] = [''] + else: + raise NotImplementedError + + return arch_cells + + +def get_arch_cells_denoising(arch_type, use_se, apply_sqrt2): + if arch_type == 'res_mbconv': + arch_cells = dict() + arch_cells['normal_enc_diff'] = { + 'conv_branch': ['mconv_e6k5g0_gn'], 'se': use_se} + arch_cells['down_enc_diff'] = { + 'conv_branch': ['mconv_e6k5g0_gn'], 'se': use_se} + arch_cells['normal_dec_diff'] = { + 'conv_branch': ['mconv_e6k5g0_gn'], 'se': use_se} + arch_cells['up_dec_diff'] = { + 'conv_branch': ['mconv_e6k5g0_gn'], 'se': use_se} + elif arch_type == 'res_ho': + arch_cells = dict() + arch_cells['normal_enc_diff'] = { + 'conv_branch': ['res_gnswish_x2'], 'se': use_se} + arch_cells['down_enc_diff'] = { + 'conv_branch': ['res_gnswish_x2'], 'se': use_se} + arch_cells['normal_dec_diff'] = { + 'conv_branch': ['res_gnswish_x2'], 'se': use_se} + arch_cells['up_dec_diff'] = { + 'conv_branch': ['res_gnswish_x2'], 'se': use_se} + elif arch_type == 'res_ho_p1': + arch_cells = dict() + arch_cells['normal_enc_diff'] = { + 'conv_branch': ['res_gnswish_x2_p1'], 'se': use_se} + arch_cells['down_enc_diff'] = { + 'conv_branch': ['res_gnswish_x2_p1'], 'se': use_se} + arch_cells['normal_dec_diff'] = { + 'conv_branch': ['res_gnswish_x2_p1'], 'se': use_se} + arch_cells['up_dec_diff'] = { + 'conv_branch': ['res_gnswish_x2_p1'], 'se': use_se} + elif arch_type == 'res_ho_attn': + arch_cells = dict() + arch_cells['normal_enc_diff'] = { + 'conv_branch': ['res_gnswish_x2'], 'se': use_se} + arch_cells['down_enc_diff'] = { + 'conv_branch': ['res_gnswish_x2'], 'se': use_se} + arch_cells['normal_dec_diff'] = { + 'conv_branch': ['res_gnswish_x2'], 'se': use_se} + arch_cells['up_dec_diff'] = { + 
'conv_branch': ['res_gnswish_x2'], 'se': use_se} + else: + raise NotImplementedError + + for k in arch_cells: + arch_cells[k]['apply_sqrt2'] = apply_sqrt2 + + return arch_cells + + +def groups_per_scale(num_scales, num_groups_per_scale): + g = [] + n = num_groups_per_scale + for s in range(num_scales): + assert n >= 1 + g.append(n) + return g + + +#class PositionalEmbedding(nn.Module): +# def __init__(self, embedding_dim, scale): +# super(PositionalEmbedding, self).__init__() +# self.embedding_dim = embedding_dim +# self.scale = scale +# +# def forward(self, timesteps): +# assert len(timesteps.shape) == 1 +# timesteps = timesteps * self.scale +# half_dim = self.embedding_dim // 2 +# emb = math.log(10000) / (half_dim - 1) +# emb = torch.exp(torch.arange(half_dim) * -emb) +# emb = emb.to(device=timesteps.device) +# emb = timesteps[:, None] * emb[None, :] +# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) +# return emb +# +# +#class RandomFourierEmbedding(nn.Module): +# def __init__(self, embedding_dim, scale): +# super(RandomFourierEmbedding, self).__init__() +# self.w = nn.Parameter(torch.randn( +# size=(1, embedding_dim // 2)) * scale, requires_grad=False) +# +# def forward(self, timesteps): +# emb = torch.mm(timesteps[:, None], self.w * 2 * 3.14159265359) +# return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) +# +# +#def init_temb_fun(embedding_type, embedding_scale, embedding_dim): +# if embedding_type == 'positional': +# temb_fun = PositionalEmbedding(embedding_dim, embedding_scale) +# elif embedding_type == 'fourier': +# temb_fun = RandomFourierEmbedding(embedding_dim, embedding_scale) +# else: +# raise NotImplementedError +# +# return temb_fun + + +def symmetrize_image_data(images): + return 2.0 * images - 1.0 + + +def unsymmetrize_image_data(images): + return (images + 1.) / 2. + + +def normalize_symmetric(images): + """ + Normalize images by dividing the largest intensity. Used for visualizing the intermediate steps. + """ + b = images.shape[0] + m, _ = torch.max(torch.abs(images).view(b, -1), dim=1) + images /= (m.view(b, 1, 1, 1) + 1e-3) + + return images + + +@torch.jit.script +def soft_clamp5(x: torch.Tensor): + # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5] + return x.div(5.).tanh_().mul(5.) + + +@torch.jit.script +def soft_clamp(x: torch.Tensor, a: torch.Tensor): + return x.div(a).tanh_().mul(a) + + +class SoftClamp5(nn.Module): + def __init__(self): + super(SoftClamp5, self).__init__() + + def forward(self, x): + return soft_clamp5(x) + + +def override_architecture_fields(args, stored_args, logging): + # list of architecture parameters used in NVAE: + architecture_fields = ['arch_instance', 'num_nf', 'num_latent_scales', 'num_groups_per_scale', + 'num_latent_per_group', 'num_channels_enc', 'num_preprocess_blocks', + 'num_preprocess_cells', 'num_cell_per_cond_enc', 'num_channels_dec', + 'num_postprocess_blocks', 'num_postprocess_cells', 'num_cell_per_cond_dec', + 'decoder_dist', 'num_x_bits', 'log_sig_q_scale', 'latent_grad_cutoff', + 'progressive_output_vae', 'progressive_input_vae', 'channel_mult'] + + # backward compatibility + """ We have broken backward compatibility. No need to se these manually + if not hasattr(stored_args, 'log_sig_q_scale'): + logging.info('*** Setting %s manually ****', 'log_sig_q_scale') + setattr(stored_args, 'log_sig_q_scale', 5.) + + if not hasattr(stored_args, 'latent_grad_cutoff'): + logging.info('*** Setting %s manually ****', 'latent_grad_cutoff') + setattr(stored_args, 'latent_grad_cutoff', 0.) 
+
+    if not hasattr(stored_args, 'progressive_input_vae'):
+        logging.info('*** Setting %s manually ****', 'progressive_input_vae')
+        setattr(stored_args, 'progressive_input_vae', 'none')
+
+    if not hasattr(stored_args, 'progressive_output_vae'):
+        logging.info('*** Setting %s manually ****', 'progressive_output_vae')
+        setattr(stored_args, 'progressive_output_vae', 'none')
+    """
+
+    for f in architecture_fields:
+        if not hasattr(args, f) or getattr(args, f) != getattr(stored_args, f):
+            logging.info('Setting %s from loaded checkpoint', f)
+            setattr(args, f, getattr(stored_args, f))
+
+
+def init_processes(rank, size, fn, args, config):
+    """ Initialize the distributed environment. """
+    os.environ['MASTER_ADDR'] = args.master_address
+    os.environ['MASTER_PORT'] = '6020'
+    if args.num_proc_node == 1:
+        import socket
+        import errno
+        a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        for p in range(6010, 6030):
+            location = (args.master_address, p)  # "127.0.0.1", p)
+            try:
+                a_socket.bind((args.master_address, p))
+                logger.debug('set port as {}', p)
+                os.environ['MASTER_PORT'] = '%d' % p
+                a_socket.close()
+                break
+            except socket.error as e:
+                pass  # port already in use; try the next one
+                # if e.errno == errno.EADDRINUSE:
+                #     # logger.debug("Port {} is already in use", p)
+                # else:
+                #     logger.debug(e)
+
+    logger.info('init_process: rank={}, world_size={}', rank, size)
+    torch.cuda.set_device(args.local_rank)
+    dist.init_process_group(
+        backend='nccl', init_method='env://', rank=rank, world_size=size)
+    fn(args, config)
+    logger.info('barrier: rank={}, world_size={}', rank, size)
+    dist.barrier()
+    logger.info('skip destroy_process_group: rank={}, world_size={}', rank, size)
+    # dist.destroy_process_group()
+    logger.info('skip destroy fini')
+
+
+def sample_rademacher_like(y):
+    return torch.randint(low=0, high=2, size=y.shape, device='cuda') * 2 - 1
+
+
+def sample_gaussian_like(y):
+    return torch.randn_like(y, device='cuda')
+
+
+def trace_df_dx_hutchinson(f, x, noise, no_autograd):
+    """
+    Hutchinson's trace estimator for Jacobian df/dx, O(1) call to autograd
+    """
+    if no_autograd:
+        # the following is compatible with checkpointing
+        torch.sum(f * noise).backward()
+        # torch.autograd.backward(tensors=[f], grad_tensors=[noise])
+        jvp = x.grad
+        trJ = torch.sum(jvp * noise, dim=[1, 2, 3])
+        x.grad = None
+    else:
+        jvp = torch.autograd.grad(f, x, noise, create_graph=False)[0]
+        trJ = torch.sum(jvp * noise, dim=[1, 2, 3])
+        # trJ = torch.einsum('bijk,bijk->b', jvp, noise)  # we could test if there's a speed difference in einsum vs sum
+
+    return trJ
+
+
+def calc_jacobian_regularization(pred_params, eps_t, dae, var_t, m_t, f_t, g2_t, var_N_t, args):
+    """
+    Calculates Jacobian regularization loss. For reference implementations, see
+    https://github.com/facebookresearch/jacobian_regularizer/blob/master/jacobian/jacobian.py or
+    https://github.com/cfinlay/ffjord-rnode/blob/master/lib/layers/odefunc.py.
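+
+    The loop below draws args.jac_reg_samples Gaussian vectors v, forms the
+    JVP d(pred_params)/d(eps_t) . v with autograd, assembles the matching
+    ODE-drift JVP for the configured args.sde_type, and returns the mean of
+    the squared norms, i.e. a Hutchinson-style estimate of the squared
+    Frobenius norm of the drift's Jacobian.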
+ """ + # eps_t_jvp = eps_t.detach() + # eps_t_jvp = eps_t.detach().requires_grad_() + if args.no_autograd_jvp: + raise NotImplementedError( + "We have not implemented no_autograd_jvp for jacobian reg.") + + jvp_ode_func_norms = [] + alpha = torch.sigmoid(dae.mixing_logit.detach()) + for _ in range(args.jac_reg_samples): + noise = sample_gaussian_like(eps_t) + jvp = torch.autograd.grad( + pred_params, eps_t, noise, create_graph=True)[0] + + if args.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde']: + jvp_ode_func = alpha * (noise * torch.sqrt(var_t) - jvp) + if not args.jac_kin_reg_drop_weights: + jvp_ode_func = f_t / torch.sqrt(var_t) * jvp_ode_func + elif args.sde_type in ['sub_vpsde', 'sub_power_vpsde']: + sigma2_N_t = (1.0 - m_t ** 2) ** 2 + m_t ** 2 + jvp_ode_func = noise * torch.sqrt(var_t) / (1.0 - m_t ** 4) - ( + (1.0 - alpha) * noise * torch.sqrt(var_t) / sigma2_N_t + alpha * jvp) + if not args.jac_kin_reg_drop_weights: + jvp_ode_func = f_t * (1.0 - m_t ** 4) / \ + torch.sqrt(var_t) * jvp_ode_func + elif args.sde_type in ['vesde']: + jvp_ode_func = (1.0 - alpha) * noise * \ + torch.sqrt(var_t) / var_N_t + alpha * jvp + if not args.jac_kin_reg_drop_weights: + jvp_ode_func = 0.5 * g2_t / torch.sqrt(var_t) * jvp_ode_func + else: + raise ValueError("Unrecognized SDE type: {}".format(args.sde_type)) + + jvp_ode_func_norms.append(jvp_ode_func.view( + eps_t.size(0), -1).pow(2).sum(dim=1, keepdim=True)) + + jac_reg_loss = torch.cat(jvp_ode_func_norms, dim=1).mean() + # jac_reg_loss = torch.mean(jvp_ode_func.view(eps_t.size(0), -1).pow(2).sum(dim=1)) + return jac_reg_loss + + +def calc_kinetic_regularization(pred_params, eps_t, dae, var_t, m_t, f_t, g2_t, var_N_t, args): + """ + Calculates kinetic regularization loss. For a reference implementation, see + https://github.com/cfinlay/ffjord-rnode/blob/master/lib/layers/wrappers/cnf_regularization.py + """ + # eps_t_kin = eps_t.detach() + + alpha = torch.sigmoid(dae.mixing_logit.detach()) + if args.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde']: + ode_func = alpha * (eps_t * torch.sqrt(var_t) - pred_params) + if not args.jac_kin_reg_drop_weights: + ode_func = f_t / torch.sqrt(var_t) * ode_func + elif args.sde_type in ['sub_vpsde', 'sub_power_vpsde']: + sigma2_N_t = (1.0 - m_t ** 2) ** 2 + m_t ** 2 + ode_func = eps_t * torch.sqrt(var_t) / (1.0 - m_t ** 4) - ( + (1.0 - alpha) * eps_t * torch.sqrt(var_t) / sigma2_N_t + alpha * pred_params) + if not args.jac_kin_reg_drop_weights: + ode_func = f_t * (1.0 - m_t ** 4) / torch.sqrt(var_t) * ode_func + elif args.sde_type in ['vesde']: + ode_func = (1.0 - alpha) * eps_t * torch.sqrt(var_t) / \ + var_N_t + alpha * pred_params + if not args.jac_kin_reg_drop_weights: + ode_func = 0.5 * g2_t / torch.sqrt(var_t) * ode_func + else: + raise ValueError("Unrecognized SDE type: {}".format(args.sde_type)) + + kin_reg_loss = torch.mean(ode_func.view( + eps_t.size(0), -1).pow(2).sum(dim=1)) + return kin_reg_loss + + +def different_p_q_objectives(iw_sample_p, iw_sample_q): + assert iw_sample_p in ['ll_uniform', 'drop_all_uniform', 'll_iw', 'drop_all_iw', 'drop_sigma2t_iw', 'rescale_iw', + 'drop_sigma2t_uniform'] + assert iw_sample_q in ['reweight_p_samples', 'll_uniform', 'll_iw'] + # Removed assert below. It may be stupid, but user can still do it. It may make sense for debugging purposes. + # assert iw_sample_p != iw_sample_q, 'It does not make sense to use the same objectives for p and q, but train ' \ + # 'with separated q and p updates. 
To reuse the p objective for q, specify ' \
+    #                                     '"reweight_p_samples" instead (for the ll-based objectives, the ' \
+    #                                     'reweighting factor will simply be 1.0 then)!'
+    # In these cases, we reuse the likelihood-based p-objective (either the uniform sampling version or the importance
+    # sampling version) also for q.
+    if iw_sample_p in ['ll_uniform', 'll_iw'] and iw_sample_q == 'reweight_p_samples':
+        return False
+    # In these cases, we are using a non-likelihood-based objective for p, and hence definitely need to use another q
+    # objective.
+    else:
+        return True
+
+
+def decoder_output(dataset, logits, fixed_log_scales=None):
+    if dataset in {'cifar10', 'celeba_64', 'celeba_256', 'imagenet_32', 'imagenet_64', 'ffhq',
+                   'lsun_bedroom_128', 'lsun_bedroom_256', 'mnist', 'omniglot',
+                   'lsun_church_256'}:
+        return PixelNormal(logits, fixed_log_scales)
+    else:
+        return PixelNormal(logits, fixed_log_scales)
+        # raise NotImplementedError
+
+
+def get_mixed_prediction(mixed_prediction, param, mixing_logit, mixing_component=None):
+    if mixed_prediction:
+        assert mixing_component is not None, 'Provide mixing component when mixed_prediction is enabled.'
+        coeff = torch.sigmoid(mixing_logit)
+        param = (1 - coeff) * mixing_component + coeff * param
+
+    return param
+
+
+def set_vesde_sigma_max(args, vae, train_queue, logging, is_distributed):
+    logging.info('')
+    logging.info(
+        'Calculating max. pairwise distance in latent space to set sigma2_max for VESDE...')
+
+    eps_list = []
+    vae.eval()
+    for step, x in enumerate(train_queue):
+        x = x[0] if len(x) > 1 else x
+        x = x.cuda()
+        x = symmetrize_image_data(x)
+
+        # run vae
+        with autocast(enabled=args.autocast_train):
+            with torch.set_grad_enabled(False):
+                logits, all_log_q, all_eps = vae(x)
+                eps = torch.cat(all_eps, dim=1)
+
+        eps_list.append(eps.detach())
+
+        # if step > 5:  ### DEBUG
+        #     break     ### DEBUG
+
+    # concat eps tensor on each GPU and then gather all on all GPUs
+    eps_this_rank = torch.cat(eps_list, dim=0)
+    if is_distributed:
+        # allocate a fresh tensor per rank: a list of aliases of one tensor
+        # would make every all_gather slot overwrite the same storage
+        eps_all_gathered = [torch.zeros_like(eps_this_rank)
+                            for _ in range(dist.get_world_size())]
+        dist.all_gather(eps_all_gathered, eps_this_rank)
+        eps_full = torch.cat(eps_all_gathered, dim=0)
+    else:
+        eps_full = eps_this_rank
+
+    # max pairwise distance squared between all latent encodings, is computed on CPU
+    eps_full = eps_full.cpu().float()
+    eps_full = eps_full.flatten(start_dim=1).unsqueeze(0)
+    max_pairwise_dist_sqr = torch.cdist(eps_full, eps_full).square().max()
+    max_pairwise_dist_sqr = max_pairwise_dist_sqr.cuda()
+
+    # to be safe, we broadcast to all GPUs if we are in distributed environment. Shouldn't be necessary in principle.
+    if is_distributed:
+        dist.broadcast(max_pairwise_dist_sqr, src=0)
+
+    args.sigma2_max = max_pairwise_dist_sqr.item()
+
+    logging.info('Done! args.sigma2_max set to {}'.format(args.sigma2_max))
+    logging.info('')
+    return args
+
+
+def mask_inactive_variables(x, is_active):
+    x = x * is_active
+    return x
+
+
+def common_x_operations(x, num_x_bits):
+    x = x[0] if len(x) > 1 else x
+    x = x.cuda()
+
+    # change bit length
+    x = change_bit_length(x, num_x_bits)
+    x = symmetrize_image_data(x)
+
+    return x
+
+
+def vae_regularization(args, vae_sn_calculator, loss_weight=None):
+    """
+    When called from hvae_trainer, args is None and loss_weight supplies the
+    spectral-regularization coefficient directly.
+    """
+    regularization_q, vae_norm_loss, vae_bn_loss, vae_wdn_coeff = 0., 0., 0., args.weight_decay_norm_vae if loss_weight is None else loss_weight
+    if loss_weight is not None or args.train_vae:
+        vae_norm_loss = vae_sn_calculator.spectral_norm_parallel()
+        vae_bn_loss = vae_sn_calculator.batchnorm_loss()
+        regularization_q = (vae_norm_loss + vae_bn_loss) * vae_wdn_coeff
+
+    return regularization_q, vae_norm_loss, vae_bn_loss, vae_wdn_coeff
+
+
+def dae_regularization(args, dae_sn_calculator, diffusion, dae, step, t, pred_params_p, eps_t_p, var_t_p, m_t_p, g2_t_p):
+    dae_wdn_coeff = args.weight_decay_norm_dae
+    dae_norm_loss = dae_sn_calculator.spectral_norm_parallel()
+    dae_bn_loss = dae_sn_calculator.batchnorm_loss()
+    regularization_p = (dae_norm_loss + dae_bn_loss) * dae_wdn_coeff
+
+    # Jacobian regularization
+    jac_reg_loss = 0.
+    if args.jac_reg_coeff > 0.0 and step % args.jac_reg_freq == 0:
+        f_t = diffusion.f(t).view(-1, 1, 1, 1)
+        var_N_t = diffusion.var_N(
+            t).view(-1, 1, 1, 1) if args.sde_type == 'vesde' else None
+        """
+        # Arash: Please remove the following if it looks correct to you, Karsten.
+        # jac_reg_loss = utils.calc_jacobian_regularization(pred_params, eps_t, dae, var_t, m_t, f_t, args)
+        if args.iw_sample_q in ['ll_uniform', 'll_iw']:
+            pred_params_jac_reg = torch.chunk(pred_params, chunks=2, dim=0)[0]
+            var_t_jac_reg, m_t_jac_reg, f_t_jac_reg = torch.chunk(var_t, chunks=2, dim=0)[0], \
+                torch.chunk(m_t, chunks=2, dim=0)[0], \
+                torch.chunk(f_t, chunks=2, dim=0)[0]
+            g2_t_jac_reg = torch.chunk(g2_t, chunks=2, dim=0)[0]
+            var_N_t_jac_reg = torch.chunk(var_N_t, chunks=2, dim=0)[0] if args.sde_type == 'vesde' else None
+        else:
+            pred_params_jac_reg = pred_params
+            var_t_jac_reg, m_t_jac_reg, f_t_jac_reg, g2_t_jac_reg, var_N_t_jac_reg = var_t, m_t, f_t, g2_t, var_N_t
+        jac_reg_loss = utils.calc_jacobian_regularization(pred_params_jac_reg, eps_t_p, dae, var_t_jac_reg, m_t_jac_reg,
+                                                          f_t_jac_reg, g2_t_jac_reg, var_N_t_jac_reg, args)
+        """
+        jac_reg_loss = calc_jacobian_regularization(pred_params_p, eps_t_p, dae, var_t_p, m_t_p,
+                                                    f_t, g2_t_p, var_N_t, args)
+        regularization_p += args.jac_reg_coeff * jac_reg_loss
+
+    # Kinetic regularization
+    kin_reg_loss = 0.
+    if args.kin_reg_coeff > 0.0:
+        f_t = diffusion.f(t).view(-1, 1, 1, 1)
+        var_N_t = diffusion.var_N(
+            t).view(-1, 1, 1, 1) if args.sde_type == 'vesde' else None
+        """
+        # Arash: Please remove the following if it looks correct to you, Karsten.
+ # kin_reg_loss = utils.calc_kinetic_regularization(pred_params, eps_t, dae, var_t, m_t, f_t, args) + if args.iw_sample_q in ['ll_uniform', 'll_iw']: + pred_params_kin_reg = torch.chunk(pred_params, chunks=2, dim=0)[0] + var_t_kin_reg, m_t_kin_reg, f_t_kin_reg = torch.chunk(var_t, chunks=2, dim=0)[0], \ + torch.chunk(m_t, chunks=2, dim=0)[0], \ + torch.chunk(f_t, chunks=2, dim=0)[0] + g2_t_kin_reg = torch.chunk(g2_t, chunks=2, dim=0)[0] + var_N_t_kin_reg = torch.chunk(var_N_t, chunks=2, dim=0)[0] if args.sde_type == 'vesde' else None + else: + pred_params_kin_reg = pred_params + var_t_kin_reg, m_t_kin_reg, f_t_kin_reg, g2_t_kin_reg, var_N_t_kin_reg = var_t, m_t, f_t, g2_t, var_N_t + kin_reg_loss = utils.calc_kinetic_regularization(pred_params_kin_reg, eps_t_p, dae, var_t_kin_reg, m_t_kin_reg, + f_t_kin_reg, g2_t_kin_reg, var_N_t_kin_reg, args) + """ + kin_reg_loss = calc_kinetic_regularization(pred_params_p, eps_t_p, dae, var_t_p, m_t_p, + f_t, g2_t_p, var_N_t, args) + regularization_p += args.kin_reg_coeff * kin_reg_loss + + return regularization_p, dae_norm_loss, dae_bn_loss, dae_wdn_coeff, jac_reg_loss, kin_reg_loss + + +def update_vae_lr(args, global_step, warmup_iters, vae_optimizer): + if global_step < warmup_iters: + lr = args.trainer.opt.lr * float(global_step) / warmup_iters + for param_group in vae_optimizer.param_groups: + param_group['lr'] = lr + # use same lr if lr for local-dae is not specified + + +def update_lr(args, global_step, warmup_iters, dae_optimizer, vae_optimizer, dae_local_optimizer=None): + if global_step < warmup_iters: + lr = args.learning_rate_dae * float(global_step) / warmup_iters + if args.learning_rate_mlogit > 0 and len(dae_optimizer.param_groups) > 1: + lr_mlogit = args.learning_rate_mlogit * \ + float(global_step) / warmup_iters + for i, param_group in enumerate(dae_optimizer.param_groups): + if i == 0: + param_group['lr'] = lr_mlogit + else: + param_group['lr'] = lr + else: + for param_group in dae_optimizer.param_groups: + param_group['lr'] = lr + # use same lr if lr for local-dae is not specified + lr = lr if args.learning_rate_dae_local <= 0 else args.learning_rate_dae_local * \ + float(global_step) / warmup_iters + if dae_local_optimizer is not None: + for param_group in dae_local_optimizer.param_groups: + param_group['lr'] = lr + + if args.train_vae: + lr = args.learning_rate_vae * float(global_step) / warmup_iters + for param_group in vae_optimizer.param_groups: + param_group['lr'] = lr + + +def start_meters(): + tr_loss_meter = AvgrageMeter() + vae_recon_meter = AvgrageMeter() + vae_kl_meter = AvgrageMeter() + vae_nelbo_meter = AvgrageMeter() + kl_per_group_ema = AvgrageMeter() + return tr_loss_meter, vae_recon_meter, vae_kl_meter, vae_nelbo_meter, kl_per_group_ema + + +def epoch_logging(args, writer, step, vae_recon_meter, vae_kl_meter, vae_nelbo_meter, tr_loss_meter, kl_per_group_ema): + average_tensor(vae_recon_meter.avg, args.distributed) + average_tensor(vae_kl_meter.avg, args.distributed) + average_tensor(vae_nelbo_meter.avg, args.distributed) + average_tensor(tr_loss_meter.avg, args.distributed) + average_tensor(kl_per_group_ema.avg, args.distributed) + + writer.add_scalar('epoch/vae_recon', vae_recon_meter.avg, step) + writer.add_scalar('epoch/vae_kl', vae_kl_meter.avg, step) + writer.add_scalar('epoch/vae_nelbo', vae_nelbo_meter.avg, step) + writer.add_scalar('epoch/total_loss', tr_loss_meter.avg, step) + # add kl value per group to tensorboard + for i in range(len(kl_per_group_ema.avg)): + writer.add_scalar('kl_value/group_%d' % + i, 
kl_per_group_ema.avg[i], step) + + +def infer_active_variables(train_queue, vae, args, device, distributed, max_iter=None): + kl_meter = AvgrageMeter() + vae.eval() + for step, x in enumerate(train_queue): + if max_iter is not None and step > max_iter: + break + tr_pts = x['tr_points'] + with autocast(enabled=args.autocast_train): + # apply vae: + with torch.set_grad_enabled(False): + # output = model.recont(val_x) ## torch.cat([val_x, tr_x])) + dist = vae.encode(tr_pts.to(device)) + eps = dist.sample()[0] + all_log_q = [dist.log_p(eps)] + ## _, all_log_q, all_eps = vae(x) + ## all_eps = vae.concat_eps_per_scale(all_eps) + ## all_log_q = vae.concat_eps_per_scale(all_log_q) + all_eps = [eps] + + def make_4d(xlist): return [ + x.unsqueeze(-1).unsqueeze(-1) if len(x.shape) == 2 else x.unsqueeze(-1) for x in xlist] + + log_q, log_p, kl_all, kl_diag = vae_terms( + make_4d(all_log_q), make_4d(all_eps)) + kl_meter.update(kl_diag[0], 1) # only the top scale + average_tensor(kl_meter.avg, distributed) + return kl_meter.avg > 0.1 diff --git a/utils/vis_helper.py b/utils/vis_helper.py new file mode 100644 index 0000000..4e7b3b3 --- /dev/null +++ b/utils/vis_helper.py @@ -0,0 +1,149 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import os +from datetime import datetime +import torchvision +from utils.checker import * +import matplotlib.cm as cm +import torch +import matplotlib.pyplot as plt +from loguru import logger +import numpy as np +import matplotlib +matplotlib.use('Agg') +from PIL import Image +# Visualization + +def plot_points(output, output_name=None): + from utils.data_helper import normalize_point_clouds + output = output.cpu() + input_list = [] + for idx in range(output.shape[0]): + pts = output[idx] + pts = normalize_point_clouds([pts]) + input_img = visualize_point_clouds_3d(pts, ['out#%d' % idx]) + input_list.append(input_img) + input_list = np.concatenate(input_list, axis=2) + img = Image.fromarray(input_list[:3].astype(np.uint8).transpose((1, 2, 0))) + if output_name is None: + output_dir = './results/nv_demos/lion/' + os.makedirs(output_dir, exist_ok=True) + output_name = os.path.join(output_dir, datetime.now().strftime("%y%m%d_%H%M%S.png")) + img.save(output_name) + # print(f'INFO save output img as {output_name}') + return output_name + +def visualize_point_clouds_3d_list(pcl_lst, title_lst, vis_order, vis_2D, bound, S): + t_list = [] + for i in range(len(pcl_lst)): + img = visualize_point_clouds_3d([pcl_lst[i]], [title_lst[i]] if title_lst is not None else None, + vis_order, vis_2D, bound, S) + t_list.append(img) + img = np.concatenate(t_list, axis=2) + return img + + +def visualize_point_clouds_3d(pcl_lst, title_lst=None, + vis_order=[2, 0, 1], vis_2D=1, bound=1.5, S=3, rgba=0): + """ + Copied and modified from https://github.com/stevenygd/PointFlow/blob/b7a9216ffcd2af49b24078156924de025c4dbfb6/utils.py#L109 + + Args: + pcl_lst: list of tensor, len $L$ = num of point sets, + each tensor in shape (N,3), range in [-1,1] + Returns: + image with $L$ column + """ + assert(type(pcl_lst) == list and torch.is_tensor(pcl_lst[0]) + ), f'expect list of tensor, get 
{type(pcl_lst)} and {type(pcl_lst[0])}'
+    if len(pcl_lst) > 1:
+        return visualize_point_clouds_3d_list(pcl_lst, title_lst, vis_order, vis_2D, bound, S)
+
+    pcl_lst = [pcl.cpu().detach().numpy() for pcl in pcl_lst]
+    if title_lst is None:
+        title_lst = [""] * len(pcl_lst)
+
+    fig = plt.figure(figsize=(3 * len(pcl_lst), 3))
+    num_col = len(pcl_lst)
+    assert(num_col == len(title_lst)
+           ), f'require same len, get {num_col} and {len(title_lst)}'
+    for idx, (pts, title) in enumerate(zip(pcl_lst, title_lst)):
+        ax1 = fig.add_subplot(1, num_col, 1 + idx, projection='3d')
+        ax1.set_title(title)
+        rgb = None
+        if type(S) is list:
+            psize = S[idx]
+        else:
+            psize = S
+        ax1.scatter(pts[:, vis_order[0]], pts[:, vis_order[1]],
+                    pts[:, vis_order[2]], s=psize, c=rgb)
+        ax1.set_xlim(-bound, bound)
+        ax1.set_ylim(-bound, bound)
+        ax1.set_zlim(-bound, bound)
+        ax1.grid(False)
+    fig.canvas.draw()
+
+    # grab the pixel buffer and dump it into a numpy array
+    res = fig2data(fig)
+    res = np.transpose(res, (2, 0, 1))  # 3,H,W
+
+    plt.close()
+
+    if vis_2D:
+        v1 = 0.5
+        v2 = 0
+        fig = plt.figure(figsize=(3 * len(pcl_lst), 3))
+        num_col = len(pcl_lst)
+        assert(num_col == len(title_lst)
+               ), f'require same len, get {num_col} and {len(title_lst)}'
+        for idx, (pts, title) in enumerate(zip(pcl_lst, title_lst)):
+            ax1 = fig.add_subplot(1, num_col, 1 + idx, projection='3d')
+            rgb = None
+            if type(S) is list:
+                psize = S[idx]
+            else:
+                psize = S
+            ax1.scatter(pts[:, vis_order[0]], pts[:, vis_order[1]],
+                        pts[:, vis_order[2]], s=psize, c=rgb)
+            ax1.set_xlim(-bound, bound)
+            ax1.set_ylim(-bound, bound)
+            ax1.set_zlim(-bound, bound)
+            ax1.grid(False)
+            ax1.set_title(title + '-2D')
+            ax1.view_init(v1, v2)  # 0.5, 0)
+
+        fig.canvas.draw()
+
+        # grab the pixel buffer and dump it into a numpy array
+        # res_2d = np.array(fig.canvas.renderer._renderer)
+        res_2d = fig2data(fig)
+        res_2d = np.transpose(res_2d, (2, 0, 1))
+        plt.close()
+
+        res = np.concatenate([res, res_2d], axis=1)
+    return res
+
+
+def fig2data(fig):
+    """
+    Adapted from https://stackoverflow.com/questions/55703105/convert-matplotlib-figure-to-numpy-array-of-same-shape
+    @brief Convert a Matplotlib figure to a 3D numpy array with RGBA channels and return it
+    @param fig a matplotlib figure
+    @return a numpy 3D array (H, W, 4) of RGBA values
+    """
+    # draw the renderer
+    ## fig.canvas.draw ( )
+
+    # Get the RGBA buffer from the figure
+    w, h = fig.canvas.get_width_height()
+    buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
+    buf = buf.reshape(h, w, 4)  # rows are height; (w, h, 4) only works for square figures
+
+    # canvas.tostring_argb gives the pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
+    buf = np.roll(buf, 3, axis=2)
+    return buf
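+
+
+if __name__ == '__main__':
+    # Minimal smoke test, with illustrative shapes: plot_points expects a
+    # (B, N, 3) tensor of point clouds in roughly [-1, 1] and writes the
+    # renderings of all B clouds side by side into one PNG, returning the
+    # saved path.
+    demo_pts = torch.rand(2, 2048, 3) * 2 - 1  # two random clouds in [-1, 1]
+    saved_to = plot_points(demo_pts)
+    logger.info('demo rendering saved to {}', saved_to)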