diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d60f8d6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,150 @@
+/out
+/data
+.vscode
+.cache
+*.pyc
+*.pyd
+*.pt
+*.so
+*.o
+*.prof
+*.swp
+*.lib
+*.obj
+*.exp
+.nfs*
+*.jpg
+*.png
+*.ply
+*.off
+*.npz
+*.txt
+# *.sh
+
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/README.md b/README.md
index b2e3392..e7bf941 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
# Shape As Points (SAP)
-[**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (13 min)**](https://youtu.be/TgR0NvYty0A)
+
+### [**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (12 min)**](https://youtu.be/TgR0NvYty0A)
![](./media/teaser_wheel.gif)
@@ -10,13 +11,164 @@ Shape As Points: A Differentiable Poisson Solver
**NeurIPS 2021 (Oral)**
-## Code is coming soon!
-
If you find our code or paper useful, please consider citing
```bibtex
@inproceedings{Peng2021SAP,
- author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Andreas, Geiger},
+ author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas},
title = {Shape As Points: A Differentiable Poisson Solver},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2021}}
```
+
+
+## Installation
+First, make sure that you have all dependencies in place.
+The simplest way to do so is to use [anaconda](https://www.anaconda.com/).
+
+You can create an anaconda environment called `sap` using
+```
+conda env create -f environment.yaml
+conda activate sap
+```
+
+Now, you can install [PyTorch3D](https://pytorch3d.org/) 0.6.0 following the [official instructions](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux):
+```sh
+pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu102_pyt190/download.html
+```
+And install [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter):
+```sh
+conda install pytorch-scatter -c pyg
+```
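+
+To verify the setup, you can optionally run a quick import check (a minimal sanity check, not part of the repo):
+```python
+import torch, pytorch3d, torch_scatter  # all three should import without errors
+print(torch.__version__)                # expect 1.9.0
+print(torch.cuda.is_available())        # expect True with cudatoolkit 10.2
+```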
+
+
+## Demo - Quick Start
+
+First, run the script to get the demo data:
+
+```bash
+bash scripts/download_demo_data.sh
+```
+
+### Optimization-based 3D Surface Reconstruction
+
+You can now quickly test our code on the data shown in the teaser. To this end, simply run:
+
+```python
+python optim_hierarchy.py configs/optim_based/teaser.yaml
+```
+This script should create a folder `out/demo_optim` where the output meshes and the optimized oriented point clouds at different grid resolutions are stored.
+
+To visualize the optimization process on the fly, you can set `o3d_show: True` in [`configs/optim_based/teaser.yaml`](https://github.com/autonomousvision/shape_as_points/tree/main/configs/optim_based/teaser.yaml).
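+For example (an excerpt showing only the relevant option):
+```yaml
+train:
+  o3d_show: True   # open an Open3D window to watch the optimization on the fly
+```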
+
+### Learning-based 3D Surface Reconstruction
+You can also test SAP on another application: reconstructing from unoriented point clouds with either **large noise** or **outliers** using a learned network.
+
+![](./media/results_large_noise.gif)
+
+For the point clouds with large noise as shown above, you can run:
+```python
+python generate.py configs/learning_based/demo_large_noise.yaml
+```
+The results can be found in `out/demo_shapenet_large_noise/generation/vis`.
+
+![](./media/results_outliers.gif)
+
+As for the point clouds with outliers, you can run:
+```python
+python generate.py configs/learning_based/demo_outlier.yaml
+```
+You can find the reconstruction in `out/demo_shapenet_outlier/generation/vis`.
+
+
+## Dataset
+
+We use different datasets for the optimization-based and learning-based settings.
+
+### Dataset for Optimization-based Reconstruction
+Here we consider the following datasets:
+- [Thingi10K](https://arxiv.org/abs/1605.04797) (synthetic)
+- [Surface Reconstruction Benchmark (SRB)](https://github.com/fwilliams/deep-geometric-prior) (real scans)
+- [MPI Dynamic FAUST](https://dfaust.is.tue.mpg.de/) (real scans)
+
+Please cite the corresponding papers if you use the data.
+
+You can download the processed dataset (~200 MB) by running:
+```bash
+bash scripts/download_optim_data.sh
+```
+
+### Dataset for Learning-based Reconstruction
+We train and evaluate on [ShapeNet](https://shapenet.org/).
+You can download the processed dataset (~220 GB) by running:
+```bash
+bash scripts/download_shapenet.sh
+```
+Afterwards, you should have the dataset in the `data/shapenet_psr` folder.
+
+Alternatively, you can also preprocess the dataset yourself. To this end, you can:
+* first download the preprocessed dataset (73.4 GB) by running [the script](https://github.com/autonomousvision/occupancy_networks#preprocessed-data) from Occupancy Networks.
+* check [`scripts/process_shapenet.py`](https://github.com/autonomousvision/shape_as_points/tree/main/scripts/process_shapenet.py), modify the base path, and run the script (the expected output layout is sketched below)
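+
+Judging from what `scripts/process_shapenet.py` writes, the processed dataset should roughly follow this layout (a sketch, not an exhaustive listing):
+```
+data/shapenet_psr/
+└── <category_id>/
+    └── <model_id>/
+        ├── pointcloud.npz   # ground-truth points and normals
+        └── psr.npz          # precomputed PSR field (float16)
+```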
+
+
+## Usage for Optimization-based 3D Reconstruction
+
+For the optimization-based setting, we recommend running with a coarse-to-fine strategy:
+```python
+python optim_hierarchy.py configs/optim_based/CONFIG.yaml
+```
+We start from a grid resolution of 32^3, and increase to 64^3, 128^3 and finally 256^3.
+
+Alternatively, you can also run on a single resolution with:
+
+```python
+python optim.py configs/optim_based/CONFIG.yaml
+```
+You might need to modify the `CONFIG.yaml` accordingly.
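+
+Instead of editing the file, you can also override individual options on the command line, in the same way `optim_hierarchy.py` invokes `optim.py` (the values below are only illustrative):
+```python
+python optim.py configs/optim_based/thingi.yaml --model:grid_res 256 --model:psr_sigma 3 --train:out_dir out/thingi/res_256
+```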
+
+## Usage for Learning-based 3D Reconstruction
+
+### Mesh Generation
+To generate meshes using a trained model, use
+```python
+python generate.py configs/learning_based/CONFIG.yaml
+```
+where you replace `CONFIG.yaml` with the correct config file.
+
+#### Use a pre-trained model
+The easiest way is to use a pre-trained model. You can do this by using one of the config files with postfix `_pretrained`.
+
+For example, for 3D reconstruction from point clouds with outliers using our model with 7x offsets, you can simply run:
+```python
+python generate.py configs/learning_based/outlier/ours_7x_pretrained.yaml
+```
+
+The script will automatically download the pretrained model and run the generation. You can find the outputs in the `out/.../generation_pretrained` folders.
+
+**Note**: these config files are for generation only, not for training new models. If you use them for training, the model will be trained from scratch, but during inference our code will still use the pretrained model.
+
+We provide the following pretrained models:
+```
+noise_small/ours.pt
+noise_large/ours.pt
+outlier/ours_1x.pt
+outlier/ours_3x.pt
+outlier/ours_5x.pt
+outlier/ours_7x.pt
+outlier/ours_3plane.pt
+```
+
+
+### Evaluation
+To evaluate a trained model, we provide the script [`eval_meshes.py`](https://github.com/autonomousvision/shape_as_points/blob/main/eval_meshes.py). You can run it using:
+```python
+python eval_meshes.py configs/learning_based/CONFIG.yaml
+```
+The script takes the meshes generated in the previous step and evaluates them using a standardized protocol. The output will be written to `.pkl` and `.csv` files in the corresponding generation folder that can be processed using [pandas](https://pandas.pydata.org/).
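+
+For instance, the per-class summary can be inspected with a few lines of pandas (a minimal sketch; the exact path depends on your config):
+```python
+import pandas as pd
+
+df = pd.read_csv('out/demo_shapenet_outlier/generation/eval_meshes.csv')
+print(df)
+```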
+
+### Training
+
+Finally, to train a new network from scratch, simply run:
+```python
+python train.py configs/learning_based/CONFIG.yaml
+```
+For available training options, please take a look at `configs/default.yaml`.
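+
+To set up your own experiment, the simplest route is a small config that inherits from one of the provided ones and overrides a few options (a sketch; the file name is hypothetical):
+```yaml
+# configs/learning_based/my_experiment.yaml
+inherit_from: configs/learning_based/noise_large/ours.yaml
+train:
+  out_dir: out/my_experiment
+  batch_size: 16
+```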
+
diff --git a/configs/default.yaml b/configs/default.yaml
new file mode 100644
index 0000000..b1205ce
--- /dev/null
+++ b/configs/default.yaml
@@ -0,0 +1,104 @@
+data:
+ dataset: Shapes3D
+ path: data/ShapeNet
+ class: null
+ data_type: img
+ input_type: pointcloud
+ dim: 3
+ num_points: 1000
+ num_gt_points: 1000
+ num_offset: 1
+ img_size: null
+ n_views_input: 20
+ n_views_per_iter: 2
+ pointcloud_noise: 0
+ pointcloud_file: pointcloud.npz
+ pointcloud_outlier_ratio: 0
+ fixed_scale: 0
+ train_split: train
+ val_split: val
+ test_split: test
+ points_file: null
+ points_iou_file: points.npz
+ points_unpackbits: true
+ padding: 0.1
+ multi_files: null
+ gt_mesh: null
+ zero_level: 0
+ only_single: False
+ sample_only_floor: False
+model:
+ apply_sigmoid: True
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 0
+ psr_tanh: False
+ normal_normalize: False
+ raster: {}
+ renderer: {}
+ encoder: null
+ predict_normal: True
+ predict_offset: True
+ s_offset: 0.001
+ local_coord: True
+ encoder_kwargs: {}
+ unet3d: False
+ unet3d_kwargs: {}
+ multi_gpu: false
+ rotate_matrix: false
+ c_dim: 512
+ sphere_radius: 0.2
+train:
+ lr: 1e-3
+ lr_pcl: 2e-2
+ input_mesh: ''
+ out_dir: out/default
+ subsample_vertex: False
+ batch_size: 1
+ n_grow_points: 0
+ resample_every: 0
+ l_weight: {}
+ w_reg_point: 0
+ w_psr: 0
+ w_raw: 0 # train with raw point cloud
+ gauss_weight: 0
+ n_sup_point: 0
+ w_normals: 0
+ total_epochs: 1000
+ print_every: 20
+ visualize_every: 1000
+ save_every: 1000
+ vis_vert_color: True
+ o3d_show: False
+ o3d_vis_pcl: True
+ o3d_window_size: 540
+ vis_rendering: False
+ vis_psr: False
+ save_video: False
+ exp_mesh: True
+ exp_pcl: True
+ checkpoint_every: 1000
+ validate_every: 2000000
+ backup_every: 100000
+ timestamp: False # add timestamp to out_dir name
+ model_selection_metric: loss
+ model_selection_mode: minimize
+ n_workers: 0
+ n_workers_val: 0
+test:
+ threshold: 0.5
+ eval_mesh: true
+ eval_pointcloud: false
+ model_file: model_best.pt
+generation:
+ batch_size: 100000
+ exp_gt: False
+ exp_oracle: false
+ exp_input: False
+ vis_n_outputs: 10
+ generate_mesh: true
+ generate_pointcloud: true
+ generation_dir: generation
+ copy_input: true
+ use_sampling: false
+ psr_resolution: 0
+ psr_sigma: 0
diff --git a/configs/learning_based/demo_large_noise.yaml b/configs/learning_based/demo_large_noise.yaml
new file mode 100644
index 0000000..cfb7da2
--- /dev/null
+++ b/configs/learning_based/demo_large_noise.yaml
@@ -0,0 +1,9 @@
+
+inherit_from: configs/learning_based/noise_large/ours.yaml
+data:
+ class: ['']
+ path: data/demo/shapenet_chair
+train:
+ out_dir: out/demo_shapenet_large_noise
+test:
+ model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt
\ No newline at end of file
diff --git a/configs/learning_based/demo_outlier.yaml b/configs/learning_based/demo_outlier.yaml
new file mode 100644
index 0000000..246ca49
--- /dev/null
+++ b/configs/learning_based/demo_outlier.yaml
@@ -0,0 +1,9 @@
+
+inherit_from: configs/learning_based/outlier/ours_7x.yaml
+data:
+ class: ['']
+ path: data/demo/shapenet_lamp
+train:
+ out_dir: out/demo_shapenet_outlier
+test:
+ model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt
\ No newline at end of file
diff --git a/configs/learning_based/noise_large/ours.yaml b/configs/learning_based/noise_large/ours.yaml
new file mode 100644
index 0000000..2c8eac1
--- /dev/null
+++ b/configs/learning_based/noise_large/ours.yaml
@@ -0,0 +1,54 @@
+data:
+ class: null
+ data_type: psr_full
+ input_type: pointcloud
+ path: data/shapenet_psr
+ num_gt_points: 10000
+ num_offset: 7
+ pointcloud_n: 3000
+ pointcloud_noise: 0.025
+model:
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+ psr_tanh: True
+ normal_normalize: False
+ predict_normal: True
+ predict_offset: True
+ c_dim: 32
+ s_offset: 0.001
+ encoder: local_pool_pointnet
+ encoder_kwargs:
+ hidden_dim: 32
+ plane_type: 'grid'
+ grid_resolution: 32
+ unet3d: True
+ unet3d_kwargs:
+ num_levels: 3
+ f_maps: 32
+ in_channels: 32
+ out_channels: 32
+ decoder: simple_local
+ decoder_kwargs:
+ sample_mode: bilinear # bilinear / nearest
+ hidden_size: 32
+train:
+ batch_size: 32
+ lr: 5e-4
+ out_dir: out/shapenet/noise_025_ours
+ w_psr: 1
+ model_selection_metric: psr_l2
+ print_every: 100
+ checkpoint_every: 200
+ validate_every: 5000
+ backup_every: 10000
+ total_epochs: 400000
+ visualize_every: 5000
+ exp_pcl: True
+ exp_mesh: True
+ n_workers: 8
+ n_workers_val: 0
+generation:
+ exp_gt: False
+ exp_input: True
+ psr_resolution: 128
+ psr_sigma: 2
diff --git a/configs/learning_based/noise_large/ours_pretrained.yaml b/configs/learning_based/noise_large/ours_pretrained.yaml
new file mode 100644
index 0000000..bec016c
--- /dev/null
+++ b/configs/learning_based/noise_large/ours_pretrained.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/learning_based/noise_large/ours.yaml
+generation:
+ generation_dir: generation_pretrained
+test:
+ model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt
\ No newline at end of file
diff --git a/configs/learning_based/noise_small/ours.yaml b/configs/learning_based/noise_small/ours.yaml
new file mode 100644
index 0000000..3b1eb3a
--- /dev/null
+++ b/configs/learning_based/noise_small/ours.yaml
@@ -0,0 +1,54 @@
+data:
+ class: null
+ data_type: psr_full
+ input_type: pointcloud
+ path: data/shapenet_psr
+ num_gt_points: 10000
+ num_offset: 7
+ pointcloud_n: 3000
+ pointcloud_noise: 0.005
+model:
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+ psr_tanh: True
+ normal_normalize: False
+ predict_normal: True
+ predict_offset: True
+ c_dim: 32
+ s_offset: 0.001
+ encoder: local_pool_pointnet
+ encoder_kwargs:
+ hidden_dim: 32
+ plane_type: 'grid'
+ grid_resolution: 32
+ unet3d: True
+ unet3d_kwargs:
+ num_levels: 3
+ f_maps: 32
+ in_channels: 32
+ out_channels: 32
+ decoder: simple_local
+ decoder_kwargs:
+ sample_mode: bilinear # bilinear / nearest
+ hidden_size: 32
+train:
+ batch_size: 32
+ lr: 5e-4
+ out_dir: out/shapenet/noise_005_ours
+ w_psr: 1
+ model_selection_metric: psr_l2
+ print_every: 100
+ checkpoint_every: 200
+ validate_every: 5000
+ backup_every: 10000
+ total_epochs: 400000
+ visualize_every: 5000
+ exp_pcl: True
+ exp_mesh: True
+ n_workers: 8
+ n_workers_val: 0
+generation:
+ exp_gt: False
+ exp_input: True
+ psr_resolution: 128
+ psr_sigma: 2
diff --git a/configs/learning_based/noise_small/ours_pretrained.yaml b/configs/learning_based/noise_small/ours_pretrained.yaml
new file mode 100644
index 0000000..fc3317e
--- /dev/null
+++ b/configs/learning_based/noise_small/ours_pretrained.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/learning_based/noise_small/ours.yaml
+generation:
+ generation_dir: generation_pretrained
+test:
+ model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_005.pt
\ No newline at end of file
diff --git a/configs/learning_based/outlier/ours_1x.yaml b/configs/learning_based/outlier/ours_1x.yaml
new file mode 100644
index 0000000..c973789
--- /dev/null
+++ b/configs/learning_based/outlier/ours_1x.yaml
@@ -0,0 +1,55 @@
+data:
+ class: null
+ data_type: psr_full
+ input_type: pointcloud
+ path: data/shapenet_psr
+ num_gt_points: 10000
+ num_offset: 1
+ pointcloud_n: 3000
+ pointcloud_noise: 0.005
+ pointcloud_outlier_ratio: 0.5
+model:
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+ psr_tanh: True
+ normal_normalize: False
+ predict_normal: True
+ predict_offset: True
+ c_dim: 32
+ s_offset: 0.001
+ encoder: local_pool_pointnet
+ encoder_kwargs:
+ hidden_dim: 32
+ plane_type: 'grid'
+ grid_resolution: 32
+ unet3d: True
+ unet3d_kwargs:
+ num_levels: 3
+ f_maps: 32
+ in_channels: 32
+ out_channels: 32
+ decoder: simple_local
+ decoder_kwargs:
+ sample_mode: bilinear # bilinear / nearest
+ hidden_size: 32
+train:
+ batch_size: 32
+ lr: 5e-4
+ out_dir: out/shapenet/outlier_ours_1x
+ w_psr: 1
+ model_selection_metric: psr_l2
+ print_every: 100
+ checkpoint_every: 200
+ validate_every: 5000
+ backup_every: 10000
+ total_epochs: 400000
+ visualize_every: 5000
+ exp_pcl: True
+ exp_mesh: True
+ n_workers: 8
+ n_workers_val: 0
+generation:
+ exp_gt: False
+ exp_input: True
+ psr_resolution: 128
+ psr_sigma: 2
diff --git a/configs/learning_based/outlier/ours_1x_pretrained.yaml b/configs/learning_based/outlier/ours_1x_pretrained.yaml
new file mode 100644
index 0000000..5aaf6b5
--- /dev/null
+++ b/configs/learning_based/outlier/ours_1x_pretrained.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/learning_based/outlier/ours_1x.yaml
+generation:
+ generation_dir: generation_pretrained
+test:
+  model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_1x.pt
\ No newline at end of file
diff --git a/configs/learning_based/outlier/ours_3plane.yaml b/configs/learning_based/outlier/ours_3plane.yaml
new file mode 100644
index 0000000..0c2aa3b
--- /dev/null
+++ b/configs/learning_based/outlier/ours_3plane.yaml
@@ -0,0 +1,54 @@
+data:
+ class: null
+ data_type: psr_full
+ input_type: pointcloud
+ path: data/shapenet_psr
+ num_gt_points: 10000
+ num_offset: 5
+ pointcloud_n: 3000
+ pointcloud_noise: 0.005
+ pointcloud_outlier_ratio: 0.5
+model:
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+ psr_tanh: True
+ normal_normalize: False
+ predict_normal: True
+ predict_offset: True
+ c_dim: 32
+ s_offset: 0.001
+ encoder: local_pool_pointnet
+ encoder_kwargs:
+ hidden_dim: 32
+ plane_type: ['xz', 'xy', 'yz']
+ plane_resolution: 64
+ unet: True
+ unet_kwargs:
+ depth: 4
+ merge_mode: concat
+ start_filts: 32
+ decoder: simple_local
+ decoder_kwargs:
+ sample_mode: bilinear # bilinear / nearest
+ hidden_size: 32
+train:
+ batch_size: 32
+ lr: 5e-4
+ out_dir: out/shapenet/outlier_ours_3plane
+ w_psr: 1
+ model_selection_metric: psr_l2
+ print_every: 100
+ checkpoint_every: 200
+ validate_every: 5000
+ backup_every: 10000
+ total_epochs: 400000
+ visualize_every: 5000
+ exp_pcl: True
+ exp_mesh: True
+ n_workers: 8
+ n_workers_val: 0
+generation:
+ exp_gt: False
+ exp_input: True
+ psr_resolution: 128
+ psr_sigma: 2
diff --git a/configs/learning_based/outlier/ours_3x.yaml b/configs/learning_based/outlier/ours_3x.yaml
new file mode 100644
index 0000000..e976867
--- /dev/null
+++ b/configs/learning_based/outlier/ours_3x.yaml
@@ -0,0 +1,55 @@
+data:
+ class: null
+ data_type: psr_full
+ input_type: pointcloud
+ path: data/shapenet_psr
+ num_gt_points: 10000
+ num_offset: 3
+ pointcloud_n: 3000
+ pointcloud_noise: 0.005
+ pointcloud_outlier_ratio: 0.5
+model:
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+ psr_tanh: True
+ normal_normalize: False
+ predict_normal: True
+ predict_offset: True
+ c_dim: 32
+ s_offset: 0.001
+ encoder: local_pool_pointnet
+ encoder_kwargs:
+ hidden_dim: 32
+ plane_type: 'grid'
+ grid_resolution: 32
+ unet3d: True
+ unet3d_kwargs:
+ num_levels: 3
+ f_maps: 32
+ in_channels: 32
+ out_channels: 32
+ decoder: simple_local
+ decoder_kwargs:
+ sample_mode: bilinear # bilinear / nearest
+ hidden_size: 32
+train:
+ batch_size: 32
+ lr: 5e-4
+ out_dir: out/shapenet/outlier_ours_3x
+ w_psr: 1
+ model_selection_metric: psr_l2
+ print_every: 100
+ checkpoint_every: 200
+ validate_every: 5000
+ backup_every: 10000
+ total_epochs: 400000
+ visualize_every: 5000
+ exp_pcl: True
+ exp_mesh: True
+ n_workers: 8
+ n_workers_val: 0
+generation:
+ exp_gt: False
+ exp_input: True
+ psr_resolution: 128
+ psr_sigma: 2
diff --git a/configs/learning_based/outlier/ours_3x_pretrained.yaml b/configs/learning_based/outlier/ours_3x_pretrained.yaml
new file mode 100644
index 0000000..57e2590
--- /dev/null
+++ b/configs/learning_based/outlier/ours_3x_pretrained.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/learning_based/outlier/ours_3x.yaml
+generation:
+ generation_dir: generation_pretrained
+test:
+  model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_3x.pt
\ No newline at end of file
diff --git a/configs/learning_based/outlier/ours_5x.yaml b/configs/learning_based/outlier/ours_5x.yaml
new file mode 100644
index 0000000..0f067c3
--- /dev/null
+++ b/configs/learning_based/outlier/ours_5x.yaml
@@ -0,0 +1,55 @@
+data:
+ class: null
+ data_type: psr_full
+ input_type: pointcloud
+ path: data/shapenet_psr
+ num_gt_points: 10000
+ num_offset: 5
+ pointcloud_n: 3000
+ pointcloud_noise: 0.005
+ pointcloud_outlier_ratio: 0.5
+model:
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+ psr_tanh: True
+ normal_normalize: False
+ predict_normal: True
+ predict_offset: True
+ c_dim: 32
+ s_offset: 0.001
+ encoder: local_pool_pointnet
+ encoder_kwargs:
+ hidden_dim: 32
+ plane_type: 'grid'
+ grid_resolution: 32
+ unet3d: True
+ unet3d_kwargs:
+ num_levels: 3
+ f_maps: 32
+ in_channels: 32
+ out_channels: 32
+ decoder: simple_local
+ decoder_kwargs:
+ sample_mode: bilinear # bilinear / nearest
+ hidden_size: 32
+train:
+ batch_size: 32
+ lr: 5e-4
+ out_dir: out/shapenet/outlier_ours_5x
+ w_psr: 1
+ model_selection_metric: psr_l2
+ print_every: 100
+ checkpoint_every: 200
+ validate_every: 5000
+ backup_every: 10000
+ total_epochs: 400000
+ visualize_every: 5000
+ exp_pcl: True
+ exp_mesh: True
+ n_workers: 8
+ n_workers_val: 0
+generation:
+ exp_gt: False
+ exp_input: True
+ psr_resolution: 128
+ psr_sigma: 2
diff --git a/configs/learning_based/outlier/ours_5x_pretrained.yaml b/configs/learning_based/outlier/ours_5x_pretrained.yaml
new file mode 100644
index 0000000..2ae21a2
--- /dev/null
+++ b/configs/learning_based/outlier/ours_5x_pretrained.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/learning_based/outlier/ours_5x.yaml
+generation:
+ generation_dir: generation_pretrained
+test:
+ model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_5x.pt
\ No newline at end of file
diff --git a/configs/learning_based/outlier/ours_7x.yaml b/configs/learning_based/outlier/ours_7x.yaml
new file mode 100644
index 0000000..4152279
--- /dev/null
+++ b/configs/learning_based/outlier/ours_7x.yaml
@@ -0,0 +1,55 @@
+data:
+ class: null
+ data_type: psr_full
+ input_type: pointcloud
+ path: data/shapenet_psr
+ num_gt_points: 10000
+ num_offset: 7
+ pointcloud_n: 3000
+ pointcloud_noise: 0.005
+ pointcloud_outlier_ratio: 0.5
+model:
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+ psr_tanh: True
+ normal_normalize: False
+ predict_normal: True
+ predict_offset: True
+ c_dim: 32
+ s_offset: 0.001
+ encoder: local_pool_pointnet
+ encoder_kwargs:
+ hidden_dim: 32
+ plane_type: 'grid'
+ grid_resolution: 32
+ unet3d: True
+ unet3d_kwargs:
+ num_levels: 3
+ f_maps: 32
+ in_channels: 32
+ out_channels: 32
+ decoder: simple_local
+ decoder_kwargs:
+ sample_mode: bilinear # bilinear / nearest
+ hidden_size: 32
+train:
+ batch_size: 32
+ lr: 5e-4
+ out_dir: out/shapenet/outlier_ours_7x
+ w_psr: 1
+ model_selection_metric: psr_l2
+ print_every: 100
+ checkpoint_every: 200
+ validate_every: 5000
+ backup_every: 10000
+ total_epochs: 400000
+ visualize_every: 5000
+ exp_pcl: True
+ exp_mesh: True
+ n_workers: 8
+ n_workers_val: 0
+generation:
+ exp_gt: False
+ exp_input: True
+ psr_resolution: 128
+ psr_sigma: 2
diff --git a/configs/learning_based/outlier/ours_7x_pretrained.yaml b/configs/learning_based/outlier/ours_7x_pretrained.yaml
new file mode 100644
index 0000000..a46e621
--- /dev/null
+++ b/configs/learning_based/outlier/ours_7x_pretrained.yaml
@@ -0,0 +1,5 @@
+inherit_from: configs/learning_based/outlier/ours_7x.yaml
+generation:
+ generation_dir: generation_pretrained
+test:
+ model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt
\ No newline at end of file
diff --git a/configs/optim_based/dfaust.yaml b/configs/optim_based/dfaust.yaml
new file mode 100644
index 0000000..226d24d
--- /dev/null
+++ b/configs/optim_based/dfaust.yaml
@@ -0,0 +1,37 @@
+data:
+ class: 'only_pcl'
+ data_type: point
+ data_path: 'data/dfaust/*.ply'
+ object_id: 0
+ num_points: 20000
+model:
+ sphere_radius: 0.2
+ grid_res: 256 # poisson grid resolution
+ psr_sigma: 2
+train:
+ schedule:
+ pcl:
+ initial: 1e-2
+ interval: 700
+ factor: 0.5
+ final: 1e-3
+ out_dir: out/dfaust
+ w_chamfer: 1
+ n_sup_point: 20000
+ batch_size: 1
+ n_grow_points: 2000
+ resample_every: 200
+ subsample_vertex: False
+ total_epochs: 1600
+ print_every: 10
+ visualize_every: 2
+ checkpoint_every: 500
+ save_every: 100
+ exp_pcl: True
+ exp_mesh: True
+ o3d_show: False
+ o3d_vis_pcl: True
+ o3d_window_size: 540
+ vis_rendering: False
+ vis_vert_color: False
+ n_workers: 0
diff --git a/configs/optim_based/dgp.yaml b/configs/optim_based/dgp.yaml
new file mode 100644
index 0000000..79f0726
--- /dev/null
+++ b/configs/optim_based/dgp.yaml
@@ -0,0 +1,38 @@
+data:
+ class: 'only_pcl'
+ data_type: point
+ data_path: 'data/deep_geometric_prior_data/*.ply'
+ object_id: 0
+ num_points: 20000
+model:
+ sphere_radius: 0.2
+ grid_res: 256 # poisson grid resolution
+ psr_sigma: 2
+train:
+ schedule:
+ pcl:
+ initial: 1e-2
+ interval: 700
+ factor: 0.5
+ final: 1e-3
+ out_dir: out/dgp
+ w_reg_point: 0
+ w_chamfer: 1
+ n_sup_point: 20000
+ batch_size: 1
+ n_grow_points: 2000
+ resample_every: 200
+ subsample_vertex: False
+ total_epochs: 1600
+ print_every: 10
+ visualize_every: 2
+ checkpoint_every: 500
+ save_every: 100
+ exp_pcl: True
+ exp_mesh: True
+ o3d_show: False
+ o3d_vis_pcl: True
+ o3d_window_size: 540
+ vis_rendering: False
+ vis_vert_color: False
+ n_workers: 0
diff --git a/configs/optim_based/teaser.yaml b/configs/optim_based/teaser.yaml
new file mode 100644
index 0000000..40fd2d4
--- /dev/null
+++ b/configs/optim_based/teaser.yaml
@@ -0,0 +1,37 @@
+data:
+ class: 'only_pcl'
+ data_type: point
+ data_path: 'data/demo/wheel.obj'
+ object_id: 0
+ num_points: 20000
+model:
+ sphere_radius: 0.2
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+train:
+ schedule:
+ pcl:
+ initial: 1e-2
+ interval: 700
+ factor: 0.5
+ final: 1e-3
+ out_dir: out/demo_optim
+ w_chamfer: 1
+ n_sup_point: 20000
+ batch_size: 1
+ n_grow_points: 2000
+ resample_every: 200
+ subsample_vertex: False
+ total_epochs: 1600
+ print_every: 10
+ visualize_every: 2
+ checkpoint_every: 500
+ save_every: 100
+ exp_pcl: True
+ exp_mesh: True
+ o3d_show: False
+ o3d_vis_pcl: True
+ o3d_window_size: 540
+ vis_rendering: False
+ vis_vert_color: False
+ n_workers: 0
\ No newline at end of file
diff --git a/configs/optim_based/thingi.yaml b/configs/optim_based/thingi.yaml
new file mode 100644
index 0000000..fa5f1a6
--- /dev/null
+++ b/configs/optim_based/thingi.yaml
@@ -0,0 +1,39 @@
+data:
+ class: 'only_pcl'
+ data_type: point
+ data_path: 'data/thingi/*.ply'
+ object_id: 0
+ num_points: 20000
+model:
+ sphere_radius: 0.2
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+train:
+ # lr_pcl: 2e-2
+ schedule:
+ pcl:
+ initial: 1e-2
+ interval: 700
+ factor: 0.5
+ final: 1e-3
+ out_dir: out/thingi
+ w_reg_point: 0
+ w_chamfer: 1
+ n_sup_point: 20000
+ batch_size: 1
+ n_grow_points: 2000
+ resample_every: 200
+ subsample_vertex: False
+ total_epochs: 1600
+ print_every: 10
+ visualize_every: 2
+ checkpoint_every: 500
+ save_every: 100
+ exp_pcl: True
+ exp_mesh: True
+ o3d_show: False
+ o3d_vis_pcl: True
+ o3d_window_size: 540
+ vis_rendering: False
+ vis_vert_color: False
+ n_workers: 0
diff --git a/configs/optim_based/thingi_noisy.yaml b/configs/optim_based/thingi_noisy.yaml
new file mode 100644
index 0000000..3abc166
--- /dev/null
+++ b/configs/optim_based/thingi_noisy.yaml
@@ -0,0 +1,39 @@
+data:
+ class: 'only_pcl'
+ data_type: point
+ data_path: 'data/thingi_noisy/*.ply'
+ object_id: 0
+ num_points: 20000
+model:
+ sphere_radius: 0.2
+ grid_res: 128 # poisson grid resolution
+ psr_sigma: 2
+train:
+ # lr_pcl: 2e-2
+ schedule:
+ pcl:
+ initial: 1e-2
+ interval: 700
+ factor: 0.5
+ final: 1e-3
+ out_dir: out/thingi_noisy
+ w_reg_point: 0
+ w_chamfer: 1
+ n_sup_point: 20000
+ batch_size: 1
+ n_grow_points: 2000
+ resample_every: 200
+ subsample_vertex: False
+ total_epochs: 1600
+ print_every: 10
+ visualize_every: 2
+ checkpoint_every: 500
+ save_every: 100
+ exp_pcl: True
+ exp_mesh: True
+ o3d_show: False
+ o3d_vis_pcl: True
+ o3d_window_size: 540
+ vis_rendering: False
+ vis_vert_color: False
+ n_workers: 0
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000..ae023d7
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,29 @@
+name: sap
+channels:
+ - anaconda
+ - conda-forge
+ - pytorch
+ - defaults
+dependencies:
+ - python=3.8
+ - pytorch=1.9.0
+ - torchvision=0.10.0
+ - cudatoolkit=10.2
+ - numpy=1.18.1
+ - matplotlib=3.4.3
+ - pyyaml=5.3.1
+ - scipy=1.4.1
+ - tqdm=4.54.0
+ - trimesh=3.8.14
+ - igl=2.2.1
+ - tensorboard=2.6.0
+ - pip
+ - pip:
+ - plyfile==0.7
+ - open3d>=0.11.1
+ - scikit-image>=0.18.0
+ - python-mnist==0.7
+ - opencv-python>=4.4
+ - av==8.0.3
+ - pykdtree==1.3.4
+ - ipdb==0.13.7
diff --git a/eval_meshes.py b/eval_meshes.py
new file mode 100644
index 0000000..96a82e5
--- /dev/null
+++ b/eval_meshes.py
@@ -0,0 +1,155 @@
+import torch
+import trimesh
+from torch.utils.data import Dataset, DataLoader
+import numpy as np; np.set_printoptions(precision=4)
+import shutil, argparse, time, os
+import pandas as pd
+from src.data import collate_remove_none, collate_stack_together, worker_init_fn
+from src.training import Trainer
+from src.model import Encode2Points
+from src.data import PointCloudField, IndexField, Shapes3dDataset
+from src.utils import load_config, load_pointcloud
+from src.eval import MeshEvaluator
+from tqdm import tqdm
+from pdb import set_trace as st
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Evaluate generated meshes.')
+ parser.add_argument('config', type=str, help='Path to config file.')
+ parser.add_argument('--no_cuda', action='store_true', default=False,
+ help='disables CUDA training')
+ parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+ parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
+
+ args = parser.parse_args()
+ cfg = load_config(args.config, 'configs/default.yaml')
+ use_cuda = not args.no_cuda and torch.cuda.is_available()
+ device = torch.device("cuda" if use_cuda else "cpu")
+ data_type = cfg['data']['data_type']
+ # Shorthands
+ out_dir = cfg['train']['out_dir']
+ generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
+
+ if cfg['generation'].get('iter', 0)!=0:
+ generation_dir += '_%04d'%cfg['generation']['iter']
+ elif args.iter is not None:
+ generation_dir += '_%04d'%args.iter
+
+ print('Evaluate meshes under %s'%generation_dir)
+
+ out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl')
+ out_file_class = os.path.join(generation_dir, 'eval_meshes.csv')
+
+ # PYTORCH VERSION > 1.0.0
+ assert(float(torch.__version__.split('.')[-3]) > 0)
+
+ pointcloud_field = PointCloudField(cfg['data']['pointcloud_file'])
+ fields = {
+ 'pointcloud': pointcloud_field,
+ 'idx': IndexField(),
+ }
+
+ print('Test split: ', cfg['data']['test_split'])
+
+ dataset_folder = cfg['data']['path']
+ dataset = Shapes3dDataset(
+ dataset_folder, fields,
+ cfg['data']['test_split'],
+ categories=cfg['data']['class'], cfg=cfg)
+
+ # Loader
+ test_loader = torch.utils.data.DataLoader(
+ dataset, batch_size=1, num_workers=0, shuffle=False)
+
+ # Evaluator
+ evaluator = MeshEvaluator(n_points=100000)
+
+ eval_dicts = []
+ print('Evaluating meshes...')
+ for it, data in enumerate(tqdm(test_loader)):
+
+ if data is None:
+ print('Invalid data.')
+ continue
+
+ mesh_dir = os.path.join(generation_dir, 'meshes')
+ pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
+
+
+ # Get index etc.
+ idx = data['idx'].item()
+ try:
+ model_dict = dataset.get_model_dict(idx)
+ except AttributeError:
+ model_dict = {'model': str(idx), 'category': 'n/a'}
+
+ modelname = model_dict['model']
+ category_id = model_dict['category']
+
+ try:
+ category_name = dataset.metadata[category_id].get('name', 'n/a')
+ except AttributeError:
+ category_name = 'n/a'
+
+ if category_id != 'n/a':
+ mesh_dir = os.path.join(mesh_dir, category_id)
+ pointcloud_dir = os.path.join(pointcloud_dir, category_id)
+
+ # Evaluate
+ pointcloud_tgt = data['pointcloud'].squeeze(0).numpy()
+ normals_tgt = data['pointcloud.normals'].squeeze(0).numpy()
+
+
+ eval_dict = {
+ 'idx': idx,
+ 'class id': category_id,
+ 'class name': category_name,
+ 'modelname':modelname,
+ }
+ eval_dicts.append(eval_dict)
+
+ # Evaluate mesh
+ if cfg['test']['eval_mesh']:
+ mesh_file = os.path.join(mesh_dir, '%s.off' % modelname)
+
+ if os.path.exists(mesh_file):
+ mesh = trimesh.load(mesh_file, process=False)
+ eval_dict_mesh = evaluator.eval_mesh(
+ mesh, pointcloud_tgt, normals_tgt)
+ for k, v in eval_dict_mesh.items():
+ eval_dict[k + ' (mesh)'] = v
+ else:
+ print('Warning: mesh does not exist: %s' % mesh_file)
+
+ # Evaluate point cloud
+ if cfg['test']['eval_pointcloud']:
+ pointcloud_file = os.path.join(
+ pointcloud_dir, '%s.ply' % modelname)
+
+ if os.path.exists(pointcloud_file):
+ pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
+ eval_dict_pcl = evaluator.eval_pointcloud(
+ pointcloud, pointcloud_tgt)
+ for k, v in eval_dict_pcl.items():
+ eval_dict[k + ' (pcl)'] = v
+ else:
+ print('Warning: pointcloud does not exist: %s'
+ % pointcloud_file)
+
+
+ # Create pandas dataframe and save
+ eval_df = pd.DataFrame(eval_dicts)
+ eval_df.set_index(['idx'], inplace=True)
+ eval_df.to_pickle(out_file)
+
+ # Create CSV file with main statistics
+ eval_df_class = eval_df.groupby(by=['class name']).mean()
+ eval_df_class.loc['mean'] = eval_df_class.mean()
+ eval_df_class.to_csv(out_file_class)
+
+ # Print results
+ print(eval_df_class)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/generate.py b/generate.py
new file mode 100644
index 0000000..08dd28b
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,217 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+import numpy as np; np.set_printoptions(precision=4)
+import shutil, argparse, time, os
+import pandas as pd
+from collections import defaultdict
+from src import config
+from src.utils import mc_from_psr, export_mesh, export_pointcloud
+from src.dpsr import DPSR
+from src.training import Trainer
+from src.model import Encode2Points
+from src.utils import load_config, load_model_manual, scale2onet, is_url, load_url
+from tqdm import tqdm
+from pdb import set_trace as st
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Generate meshes from a trained model.')
+ parser.add_argument('config', type=str, help='Path to config file.')
+ parser.add_argument('--no_cuda', action='store_true', default=False,
+ help='disables CUDA training')
+ parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+ parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
+
+ args = parser.parse_args()
+ cfg = load_config(args.config, 'configs/default.yaml')
+ use_cuda = not args.no_cuda and torch.cuda.is_available()
+ device = torch.device("cuda" if use_cuda else "cpu")
+ data_type = cfg['data']['data_type']
+ input_type = cfg['data']['input_type']
+ vis_n_outputs = cfg['generation']['vis_n_outputs']
+ if vis_n_outputs is None:
+ vis_n_outputs = -1
+ # Shorthands
+ out_dir = cfg['train']['out_dir']
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+ generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
+ out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
+ out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')
+
+ # PYTORCH VERSION > 1.0.0
+ assert(float(torch.__version__.split('.')[-3]) > 0)
+
+ dataset = config.get_dataset('test', cfg, return_idx=True)
+ test_loader = torch.utils.data.DataLoader(
+ dataset, batch_size=1, num_workers=0, shuffle=False)
+
+ model = Encode2Points(cfg).to(device)
+
+ # load model
+ try:
+ if is_url(cfg['test']['model_file']):
+ state_dict = load_url(cfg['test']['model_file'])
+ elif cfg['generation'].get('iter', 0)!=0:
+ state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% cfg['generation']['iter']))
+ generation_dir += '_%04d'%cfg['generation']['iter']
+ elif args.iter is not None:
+ state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% args.iter))
+ else:
+ state_dict = torch.load(os.path.join(out_dir, 'model_best.pt'))
+
+ load_model_manual(state_dict['state_dict'], model)
+
+    except Exception as e:
+        print('Model loading error: %s. Exiting.' % e)
+        exit()
+
+
+ # Generator
+ generator = config.get_generator(model, cfg, device=device)
+
+ # Determine what to generate
+ generate_mesh = cfg['generation']['generate_mesh']
+ generate_pointcloud = cfg['generation']['generate_pointcloud']
+
+ # Statistics
+ time_dicts = []
+
+ # Generate
+ model.eval()
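+    # DPSR instance used to compute the PSR indicator grid (here only for the oracle meshes)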
+ dpsr = DPSR(res=(cfg['generation']['psr_resolution'],
+ cfg['generation']['psr_resolution'],
+ cfg['generation']['psr_resolution']),
+ sig= cfg['generation']['psr_sigma']).to(device)
+
+
+
+ # Count how many models already created
+ model_counter = defaultdict(int)
+
+ print('Generating...')
+ for it, data in enumerate(tqdm(test_loader)):
+
+ # Output folders
+ mesh_dir = os.path.join(generation_dir, 'meshes')
+ in_dir = os.path.join(generation_dir, 'input')
+ pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
+        generation_vis_dir = os.path.join(generation_dir, 'vis')
+
+ # Get index etc.
+ idx = data['idx'].item()
+
+ try:
+ model_dict = dataset.get_model_dict(idx)
+ except AttributeError:
+ model_dict = {'model': str(idx), 'category': 'n/a'}
+
+ modelname = model_dict['model']
+ category_id = model_dict['category']
+
+ try:
+ category_name = dataset.metadata[category_id].get('name', 'n/a')
+ except AttributeError:
+ category_name = 'n/a'
+
+ if category_id != 'n/a':
+ mesh_dir = os.path.join(mesh_dir, str(category_id))
+ pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
+ in_dir = os.path.join(in_dir, str(category_id))
+
+ folder_name = str(category_id)
+ if category_name != 'n/a':
+ folder_name = str(folder_name) + '_' + category_name.split(',')[0]
+
+ generation_vis_dir = os.path.join(generation_vis_dir, folder_name)
+
+ # Create directories if necessary
+ if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir):
+ os.makedirs(generation_vis_dir)
+
+ if generate_mesh and not os.path.exists(mesh_dir):
+ os.makedirs(mesh_dir)
+
+ if generate_pointcloud and not os.path.exists(pointcloud_dir):
+ os.makedirs(pointcloud_dir)
+
+ if not os.path.exists(in_dir):
+ os.makedirs(in_dir)
+
+ # Timing dict
+ time_dict = {
+ 'idx': idx,
+ 'class id': category_id,
+ 'class name': category_name,
+ 'modelname':modelname,
+ }
+ time_dicts.append(time_dict)
+
+ # Generate outputs
+ out_file_dict = {}
+
+ if generate_mesh:
+ #! deploy the generator to a separate class
+ out = generator.generate_mesh(data)
+
+ v, f, points, normals, stats_dict = out
+ time_dict.update(stats_dict)
+
+ # Write output
+ mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname)
+ export_mesh(mesh_out_file, scale2onet(v), f)
+ out_file_dict['mesh'] = mesh_out_file
+
+ if generate_pointcloud:
+ pointcloud_out_file = os.path.join(
+ pointcloud_dir, '%s.ply' % modelname)
+ export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
+ out_file_dict['pointcloud'] = pointcloud_out_file
+
+ if cfg['generation']['copy_input']:
+ inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
+ p = data.get('inputs').to(device)
+ export_pointcloud(inputs_path, scale2onet(p))
+ out_file_dict['in'] = inputs_path
+
+ # Copy to visualization directory for first vis_n_output samples
+ c_it = model_counter[category_id]
+ if c_it < vis_n_outputs:
+ # Save output files
+ img_name = '%02d.off' % c_it
+ for k, filepath in out_file_dict.items():
+ ext = os.path.splitext(filepath)[1]
+ out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
+ % (c_it, k, ext))
+ shutil.copyfile(filepath, out_file)
+
+ # Also generate oracle meshes
+ if cfg['generation']['exp_oracle']:
+ points_gt = data.get('gt_points').to(device)
+ normals_gt = data.get('gt_points.normals').to(device)
+ psr_gt = dpsr(points_gt, normals_gt)
+ v, f, _ = mc_from_psr(psr_gt,
+ zero_level=cfg['data']['zero_level'])
+ out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
+ % (c_it, 'mesh_oracle', '.off'))
+ export_mesh(out_file, scale2onet(v), f)
+
+ model_counter[category_id] += 1
+
+
+ # Create pandas dataframe and save
+ time_df = pd.DataFrame(time_dicts)
+ time_df.set_index(['idx'], inplace=True)
+ time_df.to_pickle(out_time_file)
+
+ # Create pickle files with main statistics
+ time_df_class = time_df.groupby(by=['class name']).mean()
+ time_df_class.loc['mean'] = time_df_class.mean()
+ time_df_class.to_pickle(out_time_file_class)
+
+ # Print results
+ print('Timings [s]:')
+ print(time_df_class)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/media/results_large_noise.gif b/media/results_large_noise.gif
new file mode 100644
index 0000000..4feabda
Binary files /dev/null and b/media/results_large_noise.gif differ
diff --git a/media/results_outliers.gif b/media/results_outliers.gif
new file mode 100644
index 0000000..4b3abc0
Binary files /dev/null and b/media/results_outliers.gif differ
diff --git a/optim.py b/optim.py
new file mode 100644
index 0000000..fa552b0
--- /dev/null
+++ b/optim.py
@@ -0,0 +1,315 @@
+import torch
+import trimesh
+import shutil, argparse, time, os, glob
+
+import numpy as np; np.set_printoptions(precision=4)
+import open3d as o3d
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.utils import save_image
+from torchvision.io import write_video
+
+from src.optimization import Trainer
+from src.utils import load_config, update_config, initialize_logger, \
+ get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\
+ update_optimizer, export_pointcloud
+from skimage import measure
+from plyfile import PlyData
+from pytorch3d.ops import sample_points_from_meshes
+from pytorch3d.io import load_objs_as_meshes
+from pytorch3d.structures import Meshes
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Optimization-based 3D reconstruction.')
+ parser.add_argument('config', type=str, help='Path to config file.')
+ parser.add_argument('--no_cuda', action='store_true', default=False,
+ help='disables CUDA training')
+ parser.add_argument('--seed', type=int, default=1457, metavar='S',
+ help='random seed')
+
+ args, unknown = parser.parse_known_args()
+ cfg = load_config(args.config, 'configs/default.yaml')
+ cfg = update_config(cfg, unknown)
+ use_cuda = not args.no_cuda and torch.cuda.is_available()
+ device = torch.device("cuda" if use_cuda else "cpu")
+ data_type = cfg['data']['data_type']
+ data_class = cfg['data']['class']
+
+ print(cfg['train']['out_dir'])
+
+ # PYTORCH VERSION > 1.0.0
+ assert(float(torch.__version__.split('.')[-3]) > 0)
+
+ # boiler-plate
+ if cfg['train']['timestamp']:
+ cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
+ logger = initialize_logger(cfg)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ np.random.seed(args.seed)
+ shutil.copyfile(args.config,
+ os.path.join(cfg['train']['out_dir'], 'config.yaml'))
+
+ # tensorboardX writer
+ tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
+ if not os.path.exists(tblogdir):
+ os.makedirs(tblogdir)
+ writer = SummaryWriter(log_dir=tblogdir)
+
+ # initialize o3d visualizer
+ vis = None
+ if cfg['train']['o3d_show']:
+ vis = o3d.visualization.Visualizer()
+ vis.create_window(width=cfg['train']['o3d_window_size'],
+ height=cfg['train']['o3d_window_size'])
+
+ # initialize dataset
+ if data_type == 'point':
+ if cfg['data']['object_id'] != -1:
+ data_paths = sorted(glob.glob(cfg['data']['data_path']))
+ data_path = data_paths[cfg['data']['object_id']]
+ print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths)))
+ else:
+ data_path = cfg['data']['data_path']
+ print('Data loaded')
+ ext = data_path.split('.')[-1]
+ if ext == 'obj': # have GT mesh
+ mesh = load_objs_as_meshes([data_path], device=device)
+ # scale the mesh into unit cube
+ verts = mesh.verts_packed()
+ N = verts.shape[0]
+ center = verts.mean(0)
+ mesh.offset_verts_(-center.expand(N, 3))
+ scale = max((verts - center).abs().max(0)[0])
+ mesh.scale_verts_((1.0 / float(scale)))
+ # important for our DPSR to have the range in [0, 1), not reaching 1
+ mesh.scale_verts_(0.9)
+
+ target_pts, target_normals = sample_points_from_meshes(mesh,
+ num_samples=200000, return_normals=True)
+ elif ext == 'ply': # only have the point cloud
+ plydata = PlyData.read(data_path)
+ vertices = np.stack([plydata['vertex']['x'],
+ plydata['vertex']['y'],
+ plydata['vertex']['z']], axis=1)
+ normals = np.stack([plydata['vertex']['nx'],
+ plydata['vertex']['ny'],
+ plydata['vertex']['nz']], axis=1)
+ N = vertices.shape[0]
+ center = vertices.mean(0)
+ scale = np.max(np.max(np.abs(vertices - center), axis=0))
+ vertices -= center
+ vertices /= scale
+ vertices *= 0.9
+
+ target_pts = torch.tensor(vertices, device=device)[None].float()
+ target_normals = torch.tensor(normals, device=device)[None].float()
+ mesh = None # no GT mesh
+
+ if not torch.is_tensor(center):
+ center = torch.from_numpy(center)
+ if not torch.is_tensor(scale):
+ scale = torch.from_numpy(np.array([scale]))
+
+ data = {'target_points': target_pts,
+ 'target_normals': target_normals, # normals are never used
+ 'gt_mesh': mesh}
+
+ else:
+ raise NotImplementedError
+
+ # save the input point cloud
+ if 'target_points' in data.keys():
+ outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply')
+ if 'target_normals' in data.keys():
+ export_pointcloud(outdir_pcl, data['target_points'], data['target_normals'])
+ else:
+ export_pointcloud(outdir_pcl, data['target_points'])
+
+ # save oracle PSR mesh (mesh from our PSR using GT point+normals)
+ if data.get('gt_mesh') is not None:
+ gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0)
+ pts_gt, norms_gt = sample_points_from_meshes(data['gt_mesh'],
+ num_samples=500000, return_normals=True)
+ pts_gt = (pts_gt + 1) / 2
+ from src.dpsr import DPSR
+ dpsr_tmp = DPSR(res=(cfg['model']['grid_res'],
+ cfg['model']['grid_res'],
+ cfg['model']['grid_res']),
+ sig=cfg['model']['psr_sigma']).to(device)
+ target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
+ target = torch.tanh(target)
+ s = target.shape[-1] # size of psr_grid
+ psr_grid_numpy = target.squeeze().detach().cpu().numpy()
+ verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
+ verts = verts / s * 2. - 1 # [-1, 1]
+ mesh = o3d.geometry.TriangleMesh()
+ mesh.vertices = o3d.utility.Vector3dVector(verts)
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
+ outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply')
+ o3d.io.write_triangle_mesh(outdir_mesh, mesh)
+
+ # initialize the source point cloud given an input mesh
+ if 'input_mesh' in cfg['train'].keys() and \
+ os.path.isfile(cfg['train']['input_mesh']):
+ if cfg['train']['input_mesh'].split('/')[-2] == 'mesh':
+ mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh'])
+ verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
+ faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
+ mesh = Meshes(verts=verts, faces=faces)
+ points, normals = sample_points_from_meshes(mesh,
+ num_samples=cfg['data']['num_points'], return_normals=True)
+ # mesh is saved in the original scale of the gt
+ points -= center.float().to(device)
+ points /= scale.float().to(device)
+ points *= 0.9
+ # make sure the points are within the range of [0, 1)
+ points = points / 2. + 0.5
+ else:
+ # directly initialize from a point cloud
+ pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh'])
+ points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
+ normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
+ points -= center.float().to(device)
+ points /= scale.float().to(device)
+ points *= 0.9
+ points = points / 2. + 0.5
+ else: #! initialize our source point cloud from a sphere
+ sphere_radius = cfg['model']['sphere_radius']
+ sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius,
+ count=[256,256])
+ points, idx = sphere_mesh.sample(cfg['data']['num_points'],
+ return_index=True)
+ points += 0.5 # make sure the points are within the range of [0, 1)
+ normals = sphere_mesh.face_normals[idx]
+ points = torch.from_numpy(points).unsqueeze(0).to(device)
+ normals = torch.from_numpy(normals).unsqueeze(0).to(device)
+
+
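+    # optimize in logit space; a sigmoid later maps the points back into (0, 1)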
+ points = torch.log(points/(1-points)) # inverse sigmoid
+ inputs = torch.cat([points, normals], axis=-1).float()
+ inputs.requires_grad = True
+
+ model = None # no network
+
+ # initialize optimizer
+ cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl']
+ print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial'])
+ if 'schedule' in cfg['train']:
+ lr_schedules = get_learning_rate_schedules(cfg['train']['schedule'])
+ else:
+ lr_schedules = None
+
+ optimizer = update_optimizer(inputs, cfg,
+ epoch=0, model=model, schedule=lr_schedules)
+
+ try:
+ # load model
+ state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
+        if state_dict.get('pcl') is not None:
+ inputs = state_dict['pcl'].to(device)
+ inputs.requires_grad = True
+
+ optimizer = update_optimizer(inputs, cfg,
+ epoch=state_dict.get('epoch'), schedule=lr_schedules)
+
+ out = "Load model from epoch %d" % state_dict.get('epoch', 0)
+ print(out)
+ logger.info(out)
+    except Exception:
+        # no checkpoint found, start from scratch
+        state_dict = dict()
+
+ start_epoch = state_dict.get('epoch', -1)
+
+ trainer = Trainer(cfg, optimizer, device=device)
+ runtime = {}
+ runtime['all'] = AverageMeter()
+
+ # training loop
+ for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
+
+ # schedule the learning rate
+        if (epoch > 0) and (lr_schedules is not None):
+ if (epoch % lr_schedules[0].interval == 0):
+ adjust_learning_rate(lr_schedules, optimizer, epoch)
+ if len(lr_schedules) >1:
+ print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch,
+ lr_schedules[0].get_learning_rate(epoch),
+ lr_schedules[1].get_learning_rate(epoch)))
+ else:
+ print('[epoch {}] adjust pcl_lr to: {}'.format(epoch,
+ lr_schedules[0].get_learning_rate(epoch)))
+
+ start = time.time()
+ loss, loss_each = trainer.train_step(data, inputs, model, epoch)
+ runtime['all'].update(time.time() - start)
+
+ if epoch % cfg['train']['print_every'] == 0:
+ log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss)
+ if loss_each is not None:
+ for k, l in loss_each.items():
+ if l.item() != 0.:
+ log_text += (' loss_%s=%.5f') % (k, l.item())
+
+ log_text += (' time=%.3f / %.3f') % (runtime['all'].val,
+ runtime['all'].sum)
+ logger.info(log_text)
+ print(log_text)
+
+ # visualize point clouds and meshes
+        if (epoch % cfg['train']['visualize_every'] == 0) and (vis is not None):
+ trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)
+
+ # save outputs
+ if epoch % cfg['train']['save_every'] == 0:
+ trainer.save_mesh_pointclouds(inputs, epoch,
+ center.cpu().numpy(),
+ scale.cpu().numpy()*(1/0.9))
+
+ # save checkpoints
+        if (epoch > 0) and (epoch % cfg['train']['checkpoint_every'] == 0):
+            state = {'epoch': epoch}
+            if isinstance(inputs, torch.Tensor):
+                state['pcl'] = inputs.detach().cpu()
+
+ torch.save(state, os.path.join(cfg['train']['dir_model'],
+ '%04d' % epoch + '.pt'))
+ print("Save new model at epoch %d" % epoch)
+ logger.info("Save new model at epoch %d" % epoch)
+ torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
+
+ # resample and gradually add new points to the source pcl
+        if (epoch > 0) and \
+           (cfg['train']['resample_every'] != 0) and \
+           (epoch % cfg['train']['resample_every'] == 0) and \
+           (epoch < cfg['train']['total_epochs']):
+ inputs = trainer.point_resampling(inputs)
+ optimizer = update_optimizer(inputs, cfg,
+ epoch=epoch, model=model, schedule=lr_schedules)
+ trainer = Trainer(cfg, optimizer, device=device)
+
+ # visualize the Open3D outputs
+ if cfg['train']['o3d_show']:
+ out_video_dir = os.path.join(cfg['train']['out_dir'],
+ 'vis/o3d/video.mp4')
+ if os.path.isfile(out_video_dir):
+ os.system('rm {}'.format(out_video_dir))
+ os.system('ffmpeg -framerate 30 \
+ -start_number 0 \
+ -i {}/vis/o3d/%04d.jpg \
+ -pix_fmt yuv420p \
+ -crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
+ out_video_dir = os.path.join(cfg['train']['out_dir'],
+ 'vis/o3d/video_pcd.mp4')
+ if os.path.isfile(out_video_dir):
+ os.system('rm {}'.format(out_video_dir))
+ os.system('ffmpeg -framerate 30 \
+ -start_number 0 \
+ -i {}/vis/o3d/%04d_pcd.jpg \
+ -pix_fmt yuv420p \
+ -crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
+ print('Video saved.')
+
+if __name__ == '__main__':
+ main()
diff --git a/optim_hierarchy.py b/optim_hierarchy.py
new file mode 100644
index 0000000..23347b9
--- /dev/null
+++ b/optim_hierarchy.py
@@ -0,0 +1,69 @@
+import sys, os
+import argparse
+from src.utils import load_config
+import subprocess
+os.environ['MKL_THREADING_LAYER'] = 'GNU'
+
+def main():
+
+    parser = argparse.ArgumentParser(description='Coarse-to-fine optimization-based reconstruction.')
+ parser.add_argument('config', type=str, help='Path to config file.')
+ parser.add_argument('--start_res', type=int, default=-1, help='Resolution to start with.')
+ parser.add_argument('--object_id', type=int, default=-1, help='Object index.')
+
+ args, unknown = parser.parse_known_args()
+ cfg = load_config(args.config, 'configs/default.yaml')
+
+ resolutions=[32, 64, 128, 256]
+ iterations=[1000, 1000, 1000, 200]
+ lrs=[2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)] # reduce lr
+ for idx,(res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
+
+        # skip resolutions below --start_res or above the configured grid_res
+        if res < args.start_res:
+            continue
+        if res > cfg['model']['grid_res']:
+            continue
+
+        if res <= 128:
+            psr_sigma = 2
+        else:
+            psr_sigma = 5 if 'thingi_noisy' in args.config else 3
+
+ if args.object_id != -1:
+ out_dir = os.path.join(cfg['train']['out_dir'], 'object_%02d'%args.object_id, 'res_%d'%res)
+ else:
+ out_dir = os.path.join(cfg['train']['out_dir'], 'res_%d'%res)
+
+ # sample from mesh when resampling is enabled, otherwise reuse the pointcloud
+ init_shape='mesh' if cfg['train']['resample_every']>0 else 'pointcloud'
+
+
+ if args.object_id != -1:
+ input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
+ 'object_%02d'%args.object_id, 'res_%d' % (resolutions[idx-1]),
+ 'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
+ else:
+ input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
+ 'res_%d' % (resolutions[idx-1]),
+ 'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
+
+
+ cmd = 'export MKL_SERVICE_FORCE_INTEL=1 && '
+ cmd += "python optim.py %s --model:grid_res %d --model:psr_sigma %d \
+ --train:input_mesh %s --train:total_epochs %d \
+ --train:out_dir %s --train:lr_pcl %f \
+ --data:object_id %d" % (
+ args.config,
+ res,
+ psr_sigma,
+ input_mesh,
+ iteration,
+ out_dir,
+ lr,
+ args.object_id)
+ print(cmd)
+ os.system(cmd)
+
+if __name__=="__main__":
+ main()
diff --git a/scripts/download_demo_data.sh b/scripts/download_demo_data.sh
new file mode 100644
index 0000000..8f617b2
--- /dev/null
+++ b/scripts/download_demo_data.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+mkdir -p data
+cd data
+echo "Start downloading ..."
+wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/demo.zip
+unzip demo.zip
+rm demo.zip
+echo "Done!"
\ No newline at end of file
diff --git a/scripts/download_optim_data.sh b/scripts/download_optim_data.sh
new file mode 100644
index 0000000..43dbc47
--- /dev/null
+++ b/scripts/download_optim_data.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+mkdir -p data
+cd data
+echo "Start downloading data for optimization-based setting (~200 MB)"
+wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/optim_data.zip
+unzip optim_data.zip
+rm optim_data.zip
+echo "Done!"
\ No newline at end of file
diff --git a/scripts/download_shapenet.sh b/scripts/download_shapenet.sh
new file mode 100644
index 0000000..6d6e90c
--- /dev/null
+++ b/scripts/download_shapenet.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+mkdir -p data
+cd data
+echo "Start downloading preprocessed ShapeNet data (~220G)"
+wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/shapenet_psr.zip
+unzip shapenet_psr.zip
+rm shapenet_psr.zip
+echo "Done!"
\ No newline at end of file
diff --git a/scripts/process_shapenet.py b/scripts/process_shapenet.py
new file mode 100644
index 0000000..cde672b
--- /dev/null
+++ b/scripts/process_shapenet.py
@@ -0,0 +1,101 @@
+import os
+import torch
+import time
+import multiprocessing
+import numpy as np
+from tqdm import tqdm
+from src.dpsr import DPSR
+
+data_path = 'data/ShapeNet' # path for ShapeNet from ONet
+base = 'data' # output base directory
+dataset_name = 'shapenet_psr'
+multiprocess = True
+njobs = 8
+save_pointcloud = True
+save_psr_field = True
+resolution = 128
+zero_level = 0.0
+num_points = 100000
+padding = 1.2
+
+dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
+
+def process_one(obj):
+
+ obj_name = obj.split('/')[-1]
+ c = obj.split('/')[-2]
+
+    # create the output directory for the current object
+ out_path_cur = os.path.join(base, dataset_name, c)
+ out_path_cur_obj = os.path.join(out_path_cur, obj_name)
+ os.makedirs(out_path_cur_obj, exist_ok=True)
+
+ gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz')
+ data = np.load(gt_path)
+ points = data['points']
+ normals = data['normals']
+
+ # normalize the point to [0, 1)
+ points = points / padding + 0.5
+ # to scale back during inference, we should:
+ #! p = (p - 0.5) * padding
+
+    if save_pointcloud:
+        outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
+        # save the raw (unnormalized) points; the mapping to (0, 1) is redone at
+        # load time (see PointCloudField in src/data/fields.py)
+        np.savez(outdir, points=data['points'], normals=data['normals'])
+
+ if save_psr_field:
+ psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None],
+ torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16)
+
+ outdir = os.path.join(out_path_cur_obj, 'psr.npz')
+ np.savez(outdir, psr=psr_gt)
+
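+# Note: pointcloud.npz keeps the original ONet coordinates, while the PSR grid
+# is computed from points mapped into [0, 1). A sketch of the round trip
+# (padding = 1.2):
+#   p01 = p / padding + 0.5      # ONet cube -> unit cube, used by DPSR
+#   p   = (p01 - 0.5) * padding  # unit cube -> ONet cube, at inference time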
+
+def main(c):
+
+    print('---------------------------------------')
+    print('Processing class {}'.format(c))
+    print('---------------------------------------')
+
+ for split in ['train', 'val', 'test']:
+ fname = os.path.join(data_path, c, split+'.lst')
+ with open(fname, 'r') as f:
+ obj_list = f.read().splitlines()
+
+ obj_list = [c+'/'+s for s in obj_list]
+
+        if multiprocess:
+            pool = multiprocessing.Pool(njobs)
+            try:
+                for _ in tqdm(pool.imap_unordered(process_one, obj_list), total=len(obj_list)):
+                    pass
+            except KeyboardInterrupt:
+                # Allow ^C to interrupt from any thread.
+                exit()
+            pool.close()
+ else:
+ for obj in tqdm(obj_list):
+ process_one(obj)
+
+    print('Done processing class {}!'.format(c))
+
+
+if __name__ == "__main__":
+
+ classes = ['02691156', '02828884', '02933112',
+ '02958343', '03211117', '03001627',
+ '03636649', '03691459', '04090263',
+ '04256520', '04379243', '04401088', '04530566']
+
+
+ t_start = time.time()
+ for c in classes:
+ main(c)
+
+ t_end = time.time()
+ print('Total processing time: ', t_end - t_start)
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..2c318e3
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,146 @@
+import yaml
+from torchvision import transforms
+from src import data, generation
+from src.dpsr import DPSR
+
+
+# Generator for final mesh extraction
+def get_generator(model, cfg, device, **kwargs):
+ ''' Returns the generator object.
+
+ Args:
+        model (nn.Module): Shape As Points model
+ cfg (dict): imported yaml config
+ device (device): pytorch device
+ '''
+
+ if cfg['generation']['psr_resolution'] == 0:
+ psr_res = cfg['model']['grid_res']
+ psr_sigma = cfg['model']['psr_sigma']
+ else:
+ psr_res = cfg['generation']['psr_resolution']
+ psr_sigma = cfg['generation']['psr_sigma']
+
+ dpsr = DPSR(res=(psr_res, psr_res, psr_res),
+ sig= psr_sigma).to(device)
+
+
+ generator = generation.Generator3D(
+ model,
+ device=device,
+ threshold=cfg['data']['zero_level'],
+ sample=cfg['generation']['use_sampling'],
+ input_type = cfg['data']['input_type'],
+ padding=cfg['data']['padding'],
+ dpsr=dpsr,
+ psr_tanh=cfg['model']['psr_tanh']
+ )
+ return generator
+
+# Datasets
+def get_dataset(mode, cfg, return_idx=False):
+ ''' Returns the dataset.
+
+ Args:
+        mode (str): dataset split mode ('train', 'val', 'test' or 'vis')
+ cfg (dict): config dictionary
+ return_idx (bool): whether to include an ID field
+ '''
+ dataset_type = cfg['data']['dataset']
+ dataset_folder = cfg['data']['path']
+ categories = cfg['data']['class']
+
+ # Get split
+ splits = {
+ 'train': cfg['data']['train_split'],
+ 'val': cfg['data']['val_split'],
+ 'test': cfg['data']['test_split'],
+ 'vis': cfg['data']['val_split'],
+ }
+
+ split = splits[mode]
+
+ # Create dataset
+ if dataset_type == 'Shapes3D':
+ fields = get_data_fields(mode, cfg)
+ # Input fields
+ inputs_field = get_inputs_field(mode, cfg)
+ if inputs_field is not None:
+ fields['inputs'] = inputs_field
+
+ if return_idx:
+ fields['idx'] = data.IndexField()
+
+ dataset = data.Shapes3dDataset(
+ dataset_folder, fields,
+ split=split,
+ categories=categories,
+ cfg = cfg
+ )
+ else:
+ raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])
+
+ return dataset
+
+
+def get_inputs_field(mode, cfg):
+ ''' Returns the inputs fields.
+
+ Args:
+ mode (str): the mode which is used
+ cfg (dict): config dictionary
+ '''
+ input_type = cfg['data']['input_type']
+
+ if input_type is None:
+ inputs_field = None
+ elif input_type == 'pointcloud':
+ noise_level = cfg['data']['pointcloud_noise']
+ if cfg['data']['pointcloud_outlier_ratio']>0:
+ transform = transforms.Compose([
+ data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
+ data.PointcloudNoise(noise_level),
+ data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio'])
+ ])
+ else:
+ transform = transforms.Compose([
+ data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
+ data.PointcloudNoise(noise_level)
+ ])
+
+ data_type = cfg['data']['data_type']
+ inputs_field = data.PointCloudField(
+ cfg['data']['pointcloud_file'], data_type, transform,
+ multi_files= cfg['data']['multi_files']
+ )
+ else:
+ raise ValueError(
+ 'Invalid input type (%s)' % input_type)
+ return inputs_field
+
+def get_data_fields(mode, cfg):
+ ''' Returns the data fields.
+
+ Args:
+ mode (str): the mode which is used
+ cfg (dict): imported yaml config
+ '''
+ data_type = cfg['data']['data_type']
+ fields = {}
+
+ if (mode in ('val', 'test')):
+ transform = data.SubsamplePointcloud(100000)
+ else:
+ transform = data.SubsamplePointcloud(cfg['data']['num_gt_points'])
+
+ data_name = cfg['data']['pointcloud_file']
+ fields['gt_points'] = data.PointCloudField(data_name,
+ transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files'])
+ if data_type == 'psr_full':
+ if mode != 'test':
+ fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files'])
+ else:
+ raise ValueError('Invalid data type (%s)' % data_type)
+
+ return fields
\ No newline at end of file
diff --git a/src/data/__init__.py b/src/data/__init__.py
new file mode 100644
index 0000000..a7f46c1
--- /dev/null
+++ b/src/data/__init__.py
@@ -0,0 +1,25 @@
+
+from src.data.core import (
+ Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together
+)
+from src.data.fields import (
+ IndexField, PointCloudField, FullPSRField
+)
+from src.data.transforms import (
+ PointcloudNoise, SubsamplePointcloud,
+ PointcloudOutliers,
+)
+__all__ = [
+    # Core
+    'Shapes3dDataset',
+    'collate_remove_none',
+    'collate_stack_together',
+    'worker_init_fn',
+    # Fields
+    'IndexField',
+    'PointCloudField',
+    'FullPSRField',
+    # Transforms
+    'PointcloudNoise',
+    'SubsamplePointcloud',
+    'PointcloudOutliers',
+]
diff --git a/src/data/core.py b/src/data/core.py
new file mode 100644
index 0000000..142747c
--- /dev/null
+++ b/src/data/core.py
@@ -0,0 +1,237 @@
+import os
+import logging
+import numpy as np
+import yaml
+from torch.utils import data
+
+
+logger = logging.getLogger(__name__)
+
+
+# Fields
+class Field(object):
+ ''' Data fields class.
+ '''
+
+ def load(self, data_path, idx, category):
+ ''' Loads a data point.
+
+ Args:
+ data_path (str): path to data file
+ idx (int): index of data point
+ category (int): index of category
+ '''
+ raise NotImplementedError
+
+ def check_complete(self, files):
+ ''' Checks if set is complete.
+
+ Args:
+ files: files
+ '''
+ raise NotImplementedError
+
+
+class Shapes3dDataset(data.Dataset):
+ ''' 3D Shapes dataset class.
+ '''
+
+ def __init__(self, dataset_folder, fields, split=None,
+ categories=None, no_except=True, transform=None, cfg=None):
+        ''' Initialization of the 3D shape dataset.
+
+ Args:
+ dataset_folder (str): dataset folder
+ fields (dict): dictionary of fields
+ split (str): which split is used
+ categories (list): list of categories to use
+ no_except (bool): no exception
+ transform (callable): transformation applied to data points
+ cfg (yaml): config file
+ '''
+ # Attributes
+ self.dataset_folder = dataset_folder
+ self.fields = fields
+ self.no_except = no_except
+ self.transform = transform
+ self.cfg = cfg
+
+ # If categories is None, use all subfolders
+ if categories is None:
+ categories = os.listdir(dataset_folder)
+ categories = [c for c in categories
+ if os.path.isdir(os.path.join(dataset_folder, c))]
+
+ # Read metadata file
+ metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
+
+ if os.path.exists(metadata_file):
+ with open(metadata_file, 'r') as f:
+ self.metadata = yaml.load(f, Loader=yaml.Loader)
+ else:
+ self.metadata = {
+ c: {'id': c, 'name': 'n/a'} for c in categories
+ }
+
+ # Set index
+ for c_idx, c in enumerate(categories):
+ self.metadata[c]['idx'] = c_idx
+
+ # Get all models
+ self.models = []
+ for c_idx, c in enumerate(categories):
+ subpath = os.path.join(dataset_folder, c)
+ if not os.path.isdir(subpath):
+ logger.warning('Category %s does not exist in dataset.' % c)
+
+ if split is None:
+ self.models += [
+ {'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ]
+ ]
+
+ else:
+ split_file = os.path.join(subpath, split + '.lst')
+ with open(split_file, 'r') as f:
+ models_c = f.read().split('\n')
+
+ if '' in models_c:
+ models_c.remove('')
+
+ self.models += [
+ {'category': c, 'model': m}
+ for m in models_c
+ ]
+
+ # precompute
+ self.split = split
+
+ def __len__(self):
+ ''' Returns the length of the dataset.
+ '''
+ return len(self.models)
+
+ def __getitem__(self, idx):
+ ''' Returns an item of the dataset.
+
+ Args:
+ idx (int): ID of data point
+ '''
+
+ category = self.models[idx]['category']
+ model = self.models[idx]['model']
+ c_idx = self.metadata[category]['idx']
+
+ model_path = os.path.join(self.dataset_folder, category, model)
+ data = {}
+
+ info = c_idx
+
+ if self.cfg['data']['multi_files'] is not None:
+ idx = np.random.randint(self.cfg['data']['multi_files'])
+ if self.split != 'train':
+ idx = 0
+
+ for field_name, field in self.fields.items():
+ try:
+ field_data = field.load(model_path, idx, info)
+ except Exception:
+ if self.no_except:
+                    logger.warning(
+                        'Error occurred when loading field %s of model %s'
+                        % (field_name, model)
+                    )
+ return None
+ else:
+ raise
+
+ if isinstance(field_data, dict):
+ for k, v in field_data.items():
+ if k is None:
+ data[field_name] = v
+ else:
+ data['%s.%s' % (field_name, k)] = v
+ else:
+ data[field_name] = field_data
+
+ if self.transform is not None:
+ data = self.transform(data)
+
+ return data
+
+
+ def get_model_dict(self, idx):
+ return self.models[idx]
+
+ def test_model_complete(self, category, model):
+ ''' Tests if model is complete.
+
+ Args:
+ model (str): modelname
+ '''
+ model_path = os.path.join(self.dataset_folder, category, model)
+ files = os.listdir(model_path)
+ for field_name, field in self.fields.items():
+ if not field.check_complete(files):
+                logger.warning('Field "%s" is incomplete: %s'
+                    % (field_name, model_path))
+ return False
+
+ return True
+
+
+def collate_remove_none(batch):
+ ''' Collater that puts each data field into a tensor with outer dimension
+ batch size.
+
+ Args:
+ batch: batch
+ '''
+
+ batch = list(filter(lambda x: x is not None, batch))
+ return data.dataloader.default_collate(batch)
+
+def collate_stack_together(batch):
+ ''' Collater that puts each data field into a tensor with outer dimension
+ batch size.
+
+ Args:
+ batch: batch
+ '''
+
+ batch = list(filter(lambda x: x is not None, batch))
+ keys = batch[0].keys()
+ concat = {}
+ if len(batch)>1:
+ for key in keys:
+ key_val = [item[key] for item in batch]
+ concat[key] = np.concatenate(key_val, axis=0)
+ if key == 'inputs':
+ n_pts = [item[key].shape[0] for item in batch]
+
+ concat['batch_ind'] = np.concatenate(
+ [i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
+
+ return data.dataloader.default_collate([concat])
+ else:
+ n_pts = batch[0]['inputs'].shape[0]
+ batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int)
+ return data.dataloader.default_collate(batch)
+
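+# A sketch of what collate_stack_together produces (hypothetical shapes):
+#   batch = [{'inputs': np.zeros((100, 3))}, {'inputs': np.zeros((150, 3))}]
+#   out = collate_stack_together(batch)
+#   out['inputs']    -> tensor of shape [1, 250, 3] (point clouds concatenated)
+#   out['batch_ind'] -> tensor of shape [1, 250], mapping each point to its item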
+
+def worker_init_fn(worker_id):
+    ''' Worker init function to ensure true randomness.
+    '''
+    # seed each worker from OS-level entropy so that numpy-based augmentation
+    # differs across workers and across epochs
+
+ random_data = os.urandom(4)
+ base_seed = int.from_bytes(random_data, byteorder="big")
+ np.random.seed(base_seed + worker_id)
diff --git a/src/data/fields.py b/src/data/fields.py
new file mode 100644
index 0000000..7f78435
--- /dev/null
+++ b/src/data/fields.py
@@ -0,0 +1,118 @@
+import os
+import numpy as np
+from src.data.core import Field
+
+
+class IndexField(Field):
+ ''' Basic index field.'''
+ def load(self, model_path, idx, category):
+ ''' Loads the index field.
+
+ Args:
+ model_path (str): path to model
+ idx (int): ID of data point
+ category (int): index of category
+ '''
+ return idx
+
+ def check_complete(self, files):
+ ''' Check if field is complete.
+
+ Args:
+ files: files
+ '''
+ return True
+
+class FullPSRField(Field):
+    def __init__(self, transform=None, multi_files=None):
+        self.transform = transform
+        self.multi_files = multi_files
+
+ def load(self, model_path, idx, category):
+
+        if self.multi_files is not None:
+            psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx))
+        else:
+            psr_path = os.path.join(model_path, 'psr.npz')
+        psr_dict = np.load(psr_path)
+        psr = psr_dict['psr'].astype(np.float32)
+ data = {None: psr}
+
+ if self.transform is not None:
+ data = self.transform(data)
+
+ return data
+
+class PointCloudField(Field):
+ ''' Point cloud field.
+
+ It provides the field used for point cloud data. These are the points
+ randomly sampled on the mesh.
+
+ Args:
+        file_name (str): file name
+        data_type (str): data type, used to decide whether the input range needs rescaling
+        transform (callable): transformation applied to the data points
+        multi_files (int): number of files per model (None for a single file)
+ '''
+ def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
+ self.file_name = file_name
+ self.data_type = data_type # to make sure the range of input is correct
+ self.transform = transform
+ self.multi_files = multi_files
+ self.padding = padding
+ self.scale = scale
+
+ def load(self, model_path, idx, category):
+ ''' Loads the data point.
+
+ Args:
+ model_path (str): path to model
+ idx (int): ID of data point
+ category (int): index of category
+ '''
+ if self.multi_files is None:
+ file_path = os.path.join(model_path, self.file_name)
+ else:
+            file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx))
+
+ pointcloud_dict = np.load(file_path)
+
+ points = pointcloud_dict['points'].astype(np.float32)
+ normals = pointcloud_dict['normals'].astype(np.float32)
+
+ data = {
+ None: points,
+ 'normals': normals,
+ }
+ if self.transform is not None:
+ data = self.transform(data)
+
+ if self.data_type == 'psr_full':
+ # scale the point cloud to the range of (0, 1)
+ data[None] = data[None] / self.scale + 0.5
+
+ return data
+
+ def check_complete(self, files):
+ ''' Check if field is complete.
+
+ Args:
+ files: files
+ '''
+ complete = (self.file_name in files)
+ return complete
diff --git a/src/data/transforms.py b/src/data/transforms.py
new file mode 100644
index 0000000..8909594
--- /dev/null
+++ b/src/data/transforms.py
@@ -0,0 +1,86 @@
+import numpy as np
+
+
+# Transforms
+class PointcloudNoise(object):
+ ''' Point cloud noise transformation class.
+
+ It adds noise to point cloud data.
+
+ Args:
+        stddev (float): standard deviation of the Gaussian noise
+ '''
+
+ def __init__(self, stddev):
+ self.stddev = stddev
+
+ def __call__(self, data):
+ ''' Calls the transformation.
+
+ Args:
+ data (dictionary): data dictionary
+ '''
+ data_out = data.copy()
+ points = data[None]
+ noise = self.stddev * np.random.randn(*points.shape)
+ noise = noise.astype(np.float32)
+ data_out[None] = points + noise
+ return data_out
+
+class PointcloudOutliers(object):
+ ''' Point cloud outlier transformation class.
+
+ It adds outliers to point cloud data.
+
+ Args:
+        ratio (float): fraction of points that are replaced by outliers
+ '''
+
+ def __init__(self, ratio):
+ self.ratio = ratio
+
+ def __call__(self, data):
+ ''' Calls the transformation.
+
+ Args:
+ data (dictionary): data dictionary
+ '''
+        data_out = data.copy()
+        points = data[None].copy()  # copy so the input array is not mutated in place
+ n_points = points.shape[0]
+ n_outlier_points = int(n_points*self.ratio)
+ ind = np.random.randint(0, n_points, n_outlier_points)
+
+ outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
+ outliers = outliers.astype(np.float32)
+ points[ind] = outliers
+ data_out[None] = points
+ return data_out
+
+class SubsamplePointcloud(object):
+ ''' Point cloud subsampling transformation class.
+
+ It subsamples the point cloud data.
+
+ Args:
+ N (int): number of points to be subsampled
+ '''
+ def __init__(self, N):
+ self.N = N
+
+ def __call__(self, data):
+ ''' Calls the transformation.
+
+ Args:
+ data (dict): data dictionary
+ '''
+ data_out = data.copy()
+ points = data[None]
+
+ indices = np.random.randint(points.shape[0], size=self.N)
+ data_out[None] = points[indices, :]
+ if 'normals' in data.keys():
+ normals = data['normals']
+ data_out['normals'] = normals[indices, :]
+
+ return data_out
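+
+
+# A minimal sketch of composing these transforms (mirroring get_inputs_field in
+# src/config.py); point coordinates live under the `None` key of the data dict:
+#   from torchvision import transforms
+#   t = transforms.Compose([SubsamplePointcloud(3000), PointcloudNoise(0.005)])
+#   out = t({None: points, 'normals': normals})  # (N, 3) float32 in, out[None]: (3000, 3)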
\ No newline at end of file
diff --git a/src/data_loader.py b/src/data_loader.py
new file mode 100644
index 0000000..75ea59c
--- /dev/null
+++ b/src/data_loader.py
@@ -0,0 +1,228 @@
+import os
+import cv2
+import torch
+import torch.nn.functional as F
+import imageio
+import numpy as np
+from glob import glob
+from torch.utils import data
+from skimage import img_as_float32
+from pytorch3d.renderer import PerspectiveCameras
+# NOTE (assumption): uv_creation, used below to build the pixel-coordinate grid,
+# is expected to be provided by src.utils; load_rgb, load_mask and
+# get_camera_params are defined in this file and therefore not imported.
+from src.utils import uv_creation
+
+##################################################
+# Below are for the differentiable renderer
+# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py
+
+def load_rgb(path):
+    img = imageio.imread(path)
+    img = img_as_float32(img)
+    return img
+
+def load_mask(path):
+    alpha = imageio.imread(path, as_gray=True)
+    # `as_gray` yields floats in [0, 255] and img_as_float32 does not rescale
+    # float input, so thresholding at 127.5 is intentional
+    alpha = img_as_float32(alpha)
+    object_mask = alpha > 127.5
+
+    return object_mask
+
+
+def get_camera_params(uv, pose, intrinsics):
+ if pose.shape[1] == 7: #In case of quaternion vector representation
+ cam_loc = pose[:, 4:]
+ R = quat_to_rot(pose[:,:4])
+ p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
+ p[:, :3, :3] = R
+ p[:, :3, 3] = cam_loc
+ else: # In case of pose matrix representation
+ cam_loc = pose[:, :3, 3]
+ p = pose
+
+ batch_size, num_samples, _ = uv.shape
+
+ depth = torch.ones((batch_size, num_samples))
+ x_cam = uv[:, :, 0].view(batch_size, -1)
+ y_cam = uv[:, :, 1].view(batch_size, -1)
+ z_cam = depth.view(batch_size, -1)
+
+ pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
+
+ # permute for batch matrix product
+ pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
+
+ world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
+ ray_dirs = world_coords - cam_loc[:, None, :]
+ ray_dirs = F.normalize(ray_dirs, dim=2)
+
+ return ray_dirs, cam_loc
+
+def quat_to_rot(q):
+ batch_size, _ = q.shape
+ q = F.normalize(q, dim=1)
+ R = torch.ones((batch_size, 3,3)).cuda()
+ qr=q[:,0]
+ qi = q[:, 1]
+ qj = q[:, 2]
+ qk = q[:, 3]
+ R[:, 0, 0]=1-2 * (qj**2 + qk**2)
+ R[:, 0, 1] = 2 * (qj *qi -qk*qr)
+ R[:, 0, 2] = 2 * (qi * qk + qr * qj)
+ R[:, 1, 0] = 2 * (qj * qi + qk * qr)
+ R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
+ R[:, 1, 2] = 2*(qj*qk - qi*qr)
+ R[:, 2, 0] = 2 * (qk * qi-qj * qr)
+ R[:, 2, 1] = 2 * (qj*qk + qi*qr)
+ R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
+ return R
+
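+# Sanity check (a sketch): quat_to_rot follows the standard Hamilton convention
+# q = [w, x, y, z], so the identity quaternion gives the identity rotation:
+#   quat_to_rot(torch.tensor([[1., 0., 0., 0.]]).cuda())  # -> eye(3) as [1, 3, 3]
+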
+def lift(x, y, z, intrinsics):
+ # parse intrinsics
+ # intrinsics = intrinsics.cuda()
+ fx = intrinsics[:, 0, 0]
+ fy = intrinsics[:, 1, 1]
+ cx = intrinsics[:, 0, 2]
+ cy = intrinsics[:, 1, 2]
+ sk = intrinsics[:, 0, 1]
+
+ x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
+ y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
+
+ # homogeneous
+ return torch.stack((x_lift, y_lift, z, torch.ones_like(z)), dim=-1)
+
+
+class PixelNeRFDTUDataset(data.Dataset):
+ """
+ Processed DTU from pixelNeRF
+ """
+ def __init__(self,
+ data_dir='data/DTU',
+ scan_id=65,
+ img_size=None,
+ device=None,
+ fixed_scale=0,
+ ):
+ data_dir = os.path.join(data_dir, "scan{}".format(scan_id))
+ rgb_paths = [
+ x for x in glob(os.path.join(data_dir, "image", "*"))
+ if (x.endswith(".jpg") or x.endswith(".png"))
+ ]
+ rgb_paths = sorted(rgb_paths)
+ mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))
+ if len(mask_paths) == 0:
+ mask_paths = [None] * len(rgb_paths)
+ sel_indices = np.arange(len(rgb_paths))
+
+ cam_path = os.path.join(data_dir, "cameras.npz")
+ all_cam = np.load(cam_path)
+ all_imgs = []
+ all_poses = []
+ all_masks = []
+ all_rays = []
+ all_light_pose = []
+ all_K = []
+ all_R = []
+ all_T = []
+
+ for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)):
+
+ i = sel_indices[idx]
+ rgb = load_rgb(rgb_path)
+ mask = load_mask(mask_path)
+ rgb[~mask] = 0.
+ rgb = torch.from_numpy(rgb).float().to(device)
+ mask = torch.from_numpy(mask).float().to(device)
+
+            # optionally rescale the normalized shape (e.g. to [-0.9, 0.9])
+ if fixed_scale!=0.:
+ scale_mat_new = np.eye(4, 4)
+ scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
+ P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new
+ else:
+ P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)]
+
+ P = P[:3, :4]
+ K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
+ K = K / K[2, 2]
+
+ fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
+
+            # build PyTorch3D camera tensors (note the transposes for PyTorch3D's conventions)
+ RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0)
+ tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0)
+ focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
+ pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
+ im_size = (rgb.shape[1], rgb.shape[0])
+
+ # check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC
+ s = min(im_size)
+ focal[:, 0] = focal[:, 0] * 2 / (s-1)
+ focal[:, 1] = focal[:, 1] * 2 /(s-1)
+ pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1)
+ pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1)
+
+ camera = PerspectiveCameras(focal_length=-focal, principal_point=pc,
+ device=device, R=RR, T=tt)
+
+ # calculate camera rays
+ uv = uv_creation(im_size)[None].float()
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose()
+ pose[:3,3] = (t[:3] / t[3])[:,0]
+ pose = torch.from_numpy(pose)[None].float()
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+ intrinsics[0, 1] = 0. #! remove skew for now
+ intrinsics = torch.from_numpy(intrinsics)[None].float()
+
+
+ rays, _ = get_camera_params(uv, pose, intrinsics)
+ rays = -rays.to(device)
+
+ all_poses.append(camera)
+ all_imgs.append(rgb)
+ all_masks.append(mask)
+ all_rays.append(rays)
+ all_light_pose.append(pose)
+ # only for neural renderer
+ all_K.append(torch.tensor(K).to(device))
+ all_R.append(torch.tensor(R).to(device))
+ all_T.append(torch.tensor(t[:3]/t[3]).to(device))
+
+ all_imgs = torch.stack(all_imgs)
+ all_masks = torch.stack(all_masks)
+ all_rays = torch.stack(all_rays)
+ all_light_pose = torch.stack(all_light_pose).squeeze()
+ # only for neural renderer
+ all_K = torch.stack(all_K).float()
+ all_R = torch.stack(all_R).float()
+ all_T = torch.stack(all_T).permute(0, 2, 1).float()
+
+ uv = uv_creation((all_imgs.size(2), all_imgs.size(1)))
+ self.data = {'rgbs': all_imgs,
+ 'masks': all_masks,
+ 'poses': all_poses,
+ 'rays': all_rays,
+ 'uv': uv,
+ 'light_pose': all_light_pose, # for rendering lights
+ 'K': all_K,
+ 'R': all_R,
+ 'T': all_T,
+ }
+
+ def __len__(self):
+ return 1
+
+ def __getitem__(self, idx):
+ return self.data
\ No newline at end of file
diff --git a/src/dpsr.py b/src/dpsr.py
new file mode 100644
index 0000000..cbb5c9b
--- /dev/null
+++ b/src/dpsr.py
@@ -0,0 +1,75 @@
+
+import torch
+import torch.nn as nn
+from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
+import numpy as np
+import torch.fft
+
+class DPSR(nn.Module):
+ def __init__(self, res, sig=10, scale=True, shift=True):
+ """
+ :param res: tuple of output field resolution. eg., (128,128)
+ :param sig: degree of gaussian smoothing
+ """
+ super(DPSR, self).__init__()
+ self.res = res
+ self.sig = sig
+ self.dim = len(res)
+ self.denom = np.prod(res)
+        G = spec_gaussian_filter(res=res, sig=sig).float()
+        # G.requires_grad would be True if we also made sig a learnable parameter
+ self.omega = fftfreqs(res, dtype=torch.float32)
+ self.scale = scale
+ self.shift = shift
+ self.register_buffer("G", G)
+
+ def forward(self, V, N):
+ """
+ :param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
+ :param N: (batch, nv, 2 or 3) tensor for point normals
+ :return phi: (batch, res, res, ...) tensor of output indicator function field
+ """
+ assert(V.shape == N.shape) # [b, nv, ndims]
+ ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
+
+        # spectral representation of the rasterized normal field
+        ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
+        ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
+        N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
+
+ omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
+        omega *= 2 * np.pi # convert to angular frequencies
+ omega = omega.to(V.device)
+
+        # divergence of the smoothed normal field, computed in the spectral domain
+        DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
+
+ Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
+ Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
+ Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b]
+ Phi[tuple([0] * self.dim)] = 0
+ Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
+
+        # solve the Poisson system in the spectral domain and invert the FFT
+        phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
+
+ if self.shift or self.scale:
+ # ensure values at points are zero
+ fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
+ if self.shift: # offset points to have mean of 0
+ offset = torch.mean(fv, dim=-1) # [b,]
+ phi -= offset.view(*tuple([-1] + [1] * self.dim))
+
+ phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
+ fv0 = phi[tuple([0] * self.dim)] # [b,]
+ phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
+
+        if self.scale:
+            # rescale so the indicator spans roughly [-0.5, 0.5] with negative
+            # values outside (the first grid corner is assumed to be outside)
+            phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) * 0.5
+ return phi
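+
+
+if __name__ == "__main__":
+    # Minimal smoke test (a sketch, not part of the original module): solve for
+    # the indicator grid of a random oriented point cloud; assumes the helpers
+    # imported from src.utils above are available.
+    torch.manual_seed(0)
+    dpsr = DPSR(res=(32, 32, 32), sig=2)
+    points = torch.rand(1, 1024, 3) * 0.8 + 0.1  # coordinates must lie in (0, 1)
+    normals = torch.nn.functional.normalize(torch.randn(1, 1024, 3), dim=-1)
+    phi = dpsr(points, normals)
+    print(phi.shape)  # expected: torch.Size([1, 32, 32, 32])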
\ No newline at end of file
diff --git a/src/eval.py b/src/eval.py
new file mode 100644
index 0000000..5ce65fe
--- /dev/null
+++ b/src/eval.py
@@ -0,0 +1,168 @@
+import logging
+import numpy as np
+import trimesh
+from pykdtree.kdtree import KDTree
+
+EMPTY_PCL_DICT = {
+ 'completeness': np.sqrt(3),
+ 'accuracy': np.sqrt(3),
+ 'completeness2': 3,
+ 'accuracy2': 3,
+ 'chamfer': 6,
+}
+
+EMPTY_PCL_DICT_NORMALS = {
+ 'normals completeness': -1.,
+ 'normals accuracy': -1.,
+ 'normals': -1.,
+}
+
+logger = logging.getLogger(__name__)
+
+
+class MeshEvaluator(object):
+ ''' Mesh evaluation class.
+ It handles the mesh evaluation process.
+ Args:
+ n_points (int): number of points to be used for evaluation
+ '''
+
+ def __init__(self, n_points=100000):
+ self.n_points = n_points
+
+ def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)):
+ ''' Evaluates a mesh.
+ Args:
+ mesh (trimesh): mesh which should be evaluated
+ pointcloud_tgt (numpy array): target point cloud
+ normals_tgt (numpy array): target normals
+            thresholds (numpy array): threshold values for the F-score calculation
+ '''
+ if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
+ pointcloud, idx = mesh.sample(self.n_points, return_index=True)
+
+ pointcloud = pointcloud.astype(np.float32)
+ normals = mesh.face_normals[idx]
+ else:
+ pointcloud = np.empty((0, 3))
+ normals = np.empty((0, 3))
+
+ out_dict = self.eval_pointcloud(
+ pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
+
+ return out_dict
+
+ def eval_pointcloud(self, pointcloud, pointcloud_tgt,
+ normals=None, normals_tgt=None,
+ thresholds=np.linspace(1./1000, 1, 1000)):
+ ''' Evaluates a point cloud.
+ Args:
+ pointcloud (numpy array): predicted point cloud
+ pointcloud_tgt (numpy array): target point cloud
+ normals (numpy array): predicted normals
+ normals_tgt (numpy array): target normals
+ thresholds (numpy array): threshold values for the F-score calculation
+ '''
+ # Return maximum losses if pointcloud is empty
+ if pointcloud.shape[0] == 0:
+            logger.warning('Empty pointcloud / mesh detected!')
+ out_dict = EMPTY_PCL_DICT.copy()
+ if normals is not None and normals_tgt is not None:
+ out_dict.update(EMPTY_PCL_DICT_NORMALS)
+ return out_dict
+
+ pointcloud = np.asarray(pointcloud)
+ pointcloud_tgt = np.asarray(pointcloud_tgt)
+
+
+        # Completeness: how far are the points of the target point cloud
+        # from the predicted point cloud
+ completeness, completeness_normals = distance_p2p(
+ pointcloud_tgt, normals_tgt, pointcloud, normals
+ )
+ recall = get_threshold_percentage(completeness, thresholds)
+ completeness2 = completeness**2
+
+ completeness = completeness.mean()
+ completeness2 = completeness2.mean()
+ completeness_normals = completeness_normals.mean()
+
+        # Accuracy: how far are the points of the predicted pointcloud
+        # from the target pointcloud
+ accuracy, accuracy_normals = distance_p2p(
+ pointcloud, normals, pointcloud_tgt, normals_tgt
+ )
+ precision = get_threshold_percentage(accuracy, thresholds)
+ accuracy2 = accuracy**2
+
+ accuracy = accuracy.mean()
+ accuracy2 = accuracy2.mean()
+ accuracy_normals = accuracy_normals.mean()
+
+ # Chamfer distance
+ chamferL2 = 0.5 * (completeness2 + accuracy2)
+ normals_correctness = (
+ 0.5 * completeness_normals + 0.5 * accuracy_normals
+ )
+ chamferL1 = 0.5 * (completeness + accuracy)
+
+ # F-Score
+ F = [
+ 2 * precision[i] * recall[i] / (precision[i] + recall[i])
+ for i in range(len(precision))
+ ]
+
+ out_dict = {
+ 'completeness': completeness,
+ 'accuracy': accuracy,
+ 'normals completeness': completeness_normals,
+ 'normals accuracy': accuracy_normals,
+ 'normals': normals_correctness,
+ 'completeness2': completeness2,
+ 'accuracy2': accuracy2,
+ 'chamfer-L2': chamferL2,
+ 'chamfer-L1': chamferL1,
+ 'f-score': F[9], # threshold = 1.0%
+ 'f-score-15': F[14], # threshold = 1.5%
+ 'f-score-20': F[19], # threshold = 2.0%
+ }
+
+ return out_dict
+
+
+def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
+ ''' Computes minimal distances of each point in points_src to points_tgt.
+ Args:
+ points_src (numpy array): source points
+ normals_src (numpy array): source normals
+ points_tgt (numpy array): target points
+ normals_tgt (numpy array): target normals
+ '''
+ kdtree = KDTree(points_tgt)
+ dist, idx = kdtree.query(points_src)
+
+ if normals_src is not None and normals_tgt is not None:
+ normals_src = \
+ normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
+ normals_tgt = \
+ normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
+
+ normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
+            # Handle normals that point in the wrong direction gracefully
+            # (mostly due to the method not caring about this during generation)
+            normals_dot_product = np.abs(normals_dot_product)
+ else:
+ normals_dot_product = np.array(
+ [np.nan] * points_src.shape[0], dtype=np.float32)
+ return dist, normals_dot_product
+
+def get_threshold_percentage(dist, thresholds):
+    ''' Computes the fraction of distances that fall below each threshold.
+ Args:
+ dist (numpy array): calculated distance
+ thresholds (numpy array): threshold values for the F-score calculation
+ '''
+ in_threshold = [
+ (dist <= t).mean() for t in thresholds
+ ]
+ return in_threshold
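+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (not part of the original module): evaluate a jittered
+    # copy of a random point cloud against the original; normals are omitted, so
+    # the normal metrics come out as NaN.
+    pcl_pred = np.random.rand(1000, 3).astype(np.float32)
+    pcl_gt = (pcl_pred + 1e-4 * np.random.randn(1000, 3)).astype(np.float32)
+    evaluator = MeshEvaluator(n_points=1000)
+    metrics = evaluator.eval_pointcloud(pcl_pred, pcl_gt)
+    print('chamfer-L1: %.6f' % metrics['chamfer-L1'])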
\ No newline at end of file
diff --git a/src/generation.py b/src/generation.py
new file mode 100644
index 0000000..9abbe6d
--- /dev/null
+++ b/src/generation.py
@@ -0,0 +1,63 @@
+import torch
+import time
+import trimesh
+import numpy as np
+from src.utils import mc_from_psr
+
+class Generator3D(object):
+    ''' Generator class for Shape As Points.
+
+    It provides functions to generate the final mesh as well as refining options.
+
+ Args:
+        model (nn.Module): trained Shape As Points model
+ points_batch_size (int): batch size for points evaluation
+ threshold (float): threshold value
+ device (device): pytorch device
+ padding (float): how much padding should be used for MISE
+ sample (bool): whether z should be sampled
+ input_type (str): type of input
+ '''
+
+ def __init__(self, model, points_batch_size=100000,
+ threshold=0.5, device=None, padding=0.1,
+ sample=False, input_type = None, dpsr=None, psr_tanh=True):
+ self.model = model.to(device)
+ self.points_batch_size = points_batch_size
+ self.threshold = threshold
+ self.device = device
+ self.input_type = input_type
+ self.padding = padding
+ self.sample = sample
+ self.dpsr = dpsr
+ self.psr_tanh = psr_tanh
+
+ def generate_mesh(self, data, return_stats=True):
+ ''' Generates the output mesh.
+
+ Args:
+ data (tensor): data tensor
+ return_stats (bool): whether stats should be returned
+ '''
+ self.model.eval()
+ device = self.device
+ stats_dict = {}
+
+ p = data.get('inputs', torch.empty(1, 0)).to(device)
+
+ t0 = time.time()
+ points, normals = self.model(p)
+ t1 = time.time()
+ psr_grid = self.dpsr(points, normals)
+ t2 = time.time()
+ v, f, _ = mc_from_psr(psr_grid,
+ zero_level=self.threshold)
+ stats_dict['pcl'] = t1 - t0
+ stats_dict['dpsr'] = t2 - t1
+ stats_dict['mc'] = time.time() - t2
+ stats_dict['total'] = time.time() - t0
+
+ if return_stats:
+ return v, f, points, normals, stats_dict
+ else:
+ return v, f, points, normals
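+
+# Typical wiring (a sketch, mirroring get_generator in src/config.py): the
+# generator combines a trained Encode2Points model with a DPSR solver, e.g.
+#   generator = Generator3D(model, device=device, dpsr=dpsr, threshold=0.0)
+#   v, f, points, normals, stats = generator.generate_mesh(data)
+# where data['inputs'] holds the (noisy) input point cloud.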
\ No newline at end of file
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000..a2e978f
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,181 @@
+import torch
+import numpy as np
+import time
+from src.utils import point_rasterize, grid_interp, mc_from_psr, \
+calc_inters_points
+from src.dpsr import DPSR
+import torch.nn as nn
+from src.network import encoder_dict, decoder_dict
+from src.network.utils import map2local
+
+class PSR2Mesh(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, psr_grid):
+ """
+ In the forward pass we receive a Tensor containing the input and return
+ a Tensor containing the output. ctx is a context object that can be used
+ to stash information for backward computation. You can cache arbitrary
+ objects for use in the backward pass using the ctx.save_for_backward method.
+ """
+ verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
+ verts = verts.unsqueeze(0)
+ faces = faces.unsqueeze(0)
+ normals = normals.unsqueeze(0)
+
+ res = torch.tensor(psr_grid.detach().shape[2])
+ ctx.save_for_backward(verts, normals, res)
+
+ return verts, faces, normals
+
+ @staticmethod
+ def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
+ """
+ In the backward pass we receive a Tensor containing the gradient of the loss
+ with respect to the output, and we need to compute the gradient of the loss
+ with respect to the input.
+ """
+ vert_pts, normals, res = ctx.saved_tensors
+ res = (res.item(), res.item(), res.item())
+ # matrix multiplication between dL/dV and dV/dPSR
+ # dV/dPSR = - normals
+ grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
+ grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
+
+ return grad_grid
+
+class PSR2SurfacePoints(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
+ verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
+ verts = verts * 2. - 1. # within the range of [-1, 1]
+
+
+ p_all, n_all, mask_all = [], [], []
+
+ for i in range(len(poses)):
+ pose = poses[i]
+ if mask_sample is not None:
+ p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i])
+ else:
+ p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size)
+
+ n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze()
+ p_all.append(p_inters)
+ n_all.append(n_inters)
+ mask_all.append(mask)
+ p_inters_all = torch.cat(p_all, dim=0)
+ n_inters_all = torch.cat(n_all, dim=0)
+ mask_visible = torch.stack(mask_all, dim=0)
+
+
+ res = torch.tensor(psr_grid.detach().shape[2])
+ ctx.save_for_backward(p_inters_all, n_inters_all, res)
+
+ return p_inters_all, mask_visible
+
+ @staticmethod
+ def backward(ctx, dL_dp, dL_dmask):
+ pts, pts_n, res = ctx.saved_tensors
+ res = (res.item(), res.item(), res.item())
+
+ # grad from the p_inters via MLP renderer
+ grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
+ grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
+
+ return grad_grid_pts, None, None, None, None, None
+
+class Encode2Points(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+
+ encoder = cfg['model']['encoder']
+ decoder = cfg['model']['decoder']
+ dim = cfg['data']['dim'] # input dim
+ c_dim = cfg['model']['c_dim']
+ encoder_kwargs = cfg['model']['encoder_kwargs']
+        if encoder_kwargs is None:
+ encoder_kwargs = {}
+ decoder_kwargs = cfg['model']['decoder_kwargs']
+ padding = cfg['data']['padding']
+ self.predict_normal = cfg['model']['predict_normal']
+ self.predict_offset = cfg['model']['predict_offset']
+
+ out_dim = 3
+ out_dim_offset = 3
+ num_offset = cfg['data']['num_offset']
+ # each point predict more than one offset to add output points
+ if num_offset > 1:
+ out_dim_offset = out_dim * num_offset
+ self.num_offset = num_offset
+
+        # local mapping
+        local_mapping = None  # stays None when local coordinates are disabled
+        if cfg['model']['local_coord']:
+            if 'unet' in encoder_kwargs.keys():
+                unit_size = 1 / encoder_kwargs['plane_resolution']
+            else:
+                unit_size = 1 / encoder_kwargs['grid_resolution']
+
+            local_mapping = map2local(unit_size)
+        self.map2local = local_mapping
+
+ self.encoder = encoder_dict[encoder](
+ dim=dim, c_dim=c_dim, map2local=local_mapping,
+ **encoder_kwargs
+ )
+
+ if self.predict_normal:
+ # decoder for normal prediction
+ self.decoder_normal = decoder_dict[decoder](
+ dim=dim, c_dim=c_dim, out_dim=out_dim,
+ **decoder_kwargs)
+ if self.predict_offset:
+ # decoder for offset prediction
+ self.decoder_offset = decoder_dict[decoder](
+ dim=dim, c_dim=c_dim, out_dim=out_dim_offset,
+ map2local=local_mapping,
+ **decoder_kwargs)
+
+ self.s_off = cfg['model']['s_offset']
+
+
+ def forward(self, p):
+ ''' Performs a forward pass through the network.
+
+ Args:
+ p (tensor): input unoriented points
+ '''
+
+ time_dict = {}
+
+ batch_size = p.size(0)
+ points = p.clone()
+
+ # encode the input point cloud to a feature volume
+ t0 = time.perf_counter()
+ c = self.encoder(p)
+ t1 = time.perf_counter()
+ if self.predict_offset:
+ offset = self.decoder_offset(p, c)
+ # more than one offset is predicted per-point
+ if self.num_offset > 1:
+ points = points.repeat(1, 1, self.num_offset).reshape(batch_size, -1, 3)
+ points = points + self.s_off * offset
+ else:
+ points = p
+
+        normals = None  # stays None when normal prediction is disabled
+        if self.predict_normal:
+            normals = self.decoder_normal(points, c)
+ t2 = time.perf_counter()
+
+ time_dict['encode'] = t1 - t0
+ time_dict['predict'] = t2 - t1
+
+ points = torch.clamp(points, 0.0, 0.99)
+ if self.cfg['model']['normal_normalize']:
+ normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8)
+
+
+ return points, normals
+
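+# Shape contract (a sketch): given p of shape [B, N, 3] with coordinates in
+# [0, 1), forward() returns (points, normals), both [B, N * num_offset, 3];
+# points are clamped to [0, 0.99] so they stay inside the DPSR grid.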
\ No newline at end of file
diff --git a/src/model_rgb.py b/src/model_rgb.py
new file mode 100644
index 0000000..79f20a1
--- /dev/null
+++ b/src/model_rgb.py
@@ -0,0 +1,139 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from src.dpsr import DPSR
+from src.model import PSR2Mesh, PSR2SurfacePoints
+from src.network.net_rgb import RenderingNetwork
+from src.utils import grid_interp
+# approx_psr_grad is defined locally below, so it is not imported from src.utils
+from pytorch3d.renderer import (
+    RasterizationSettings,
+    PerspectiveCameras,
+    MeshRenderer,
+    MeshRasterizer,
+    SoftSilhouetteShader)
+from pytorch3d.structures import Meshes
+
+
+def approx_psr_grad(psr_grid, res, normalize=True):
+ delta_x = delta_y = delta_z = 1/res
+ psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()
+
+ grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x
+ grad_y = (psr_pad[:, 2:, :] - psr_pad[:, :-2, :]) / 2 / delta_y
+ grad_z = (psr_pad[:, :, 2:] - psr_pad[:, :, :-2]) / 2 / delta_z
+ grad_x = grad_x[:, 1:-1, 1:-1]
+ grad_y = grad_y[1:-1, :, 1:-1]
+ grad_z = grad_z[1:-1, 1:-1, :]
+
+ psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=3) # [res_x, res_y, res_z, 3]
+ if normalize:
+ psr_grad = psr_grad / (psr_grad.norm(dim=3, keepdim=True) + 1e-12)
+
+ return psr_grad
+
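+# approx_psr_grad above is a finite-difference sketch of the indicator gradient:
+# central differences on the replication-padded grid, grad_x[i] ~
+# (psr[i+1] - psr[i-1]) / (2 * delta) with delta = 1/res; after normalization it
+# serves as the surface normal at ray-surface intersection points.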
+
+class SAP2Image(nn.Module):
+ def __init__(self, cfg, img_size):
+ super().__init__()
+
+ self.psr2sur = PSR2SurfacePoints.apply
+ self.psr2mesh = PSR2Mesh.apply
+ # initialize DPSR
+ self.dpsr = DPSR(res=(cfg['model']['grid_res'],
+ cfg['model']['grid_res'],
+ cfg['model']['grid_res']),
+ sig=cfg['model']['psr_sigma'])
+ self.cfg = cfg
+ if cfg['train']['l_weight']['rgb'] != 0.:
+ self.rendering_network = RenderingNetwork(**cfg['model']['renderer'])
+
+ if cfg['train']['l_weight']['mask'] != 0.:
+ # initialize rasterizer
+ sigma = 1e-4
+ raster_settings_soft = RasterizationSettings(
+ image_size=img_size,
+ blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
+ faces_per_pixel=150,
+ perspective_correct=False
+ )
+
+ # initialize silhouette renderer
+ self.mesh_rasterizer = MeshRenderer(
+ rasterizer=MeshRasterizer(
+ raster_settings=raster_settings_soft
+ ),
+ shader=SoftSilhouetteShader()
+ )
+
+ self.cfg = cfg
+ self.img_size = img_size
+
+ def forward(self, inputs, data):
+ points, normals = inputs[...,:3], inputs[...,3:]
+ points = torch.sigmoid(points)
+ normals = normals / normals.norm(dim=-1, keepdim=True)
+
+ # DPSR to get grid
+ psr_grid = self.dpsr(points, normals).unsqueeze(1)
+ psr_grid = torch.tanh(psr_grid)
+
+ return self.render_img(psr_grid, data)
+
+ def render_img(self, psr_grid, data):
+
+ n_views = len(data['masks'])
+ n_views_per_iter = self.cfg['data']['n_views_per_iter']
+
+ rgb_render_mode = self.cfg['model']['renderer']['mode']
+ uv = data['uv']
+
+ idx = np.random.randint(0, n_views, n_views_per_iter)
+ pose = [data['poses'][i] for i in idx]
+ rgb = data['rgbs'][idx]
+ mask_gt = data['masks'][idx]
+ ray = None
+ pred_rgb = None
+ pred_mask = None
+
+ if self.cfg['train']['l_weight']['rgb'] != 0.:
+ psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res'])
+ p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
+ n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
+ fea_interp = None
+ if 'rays' in data.keys():
+ ray = data['rays'].squeeze()[idx][visible_mask]
+ pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp)
+
+ # silhouette loss
+ if self.cfg['train']['l_weight']['mask'] != 0.:
+ # build mesh
+ v, f, _ = self.psr2mesh(psr_grid)
+ v = v * 2. - 1 # within the range of [-1, 1]
+ # ! Fast but more GPU usage
+ mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])
+ if True:
+ #! PyTorch3D silhouette loss
+ # build pose
+ R = torch.cat([p.R for p in pose], dim=0)
+ T = torch.cat([p.T for p in pose], dim=0)
+ focal = torch.cat([p.focal_length for p in pose], dim=0)
+ pp = torch.cat([p.principal_point for p in pose], dim=0)
+ pose_cur = PerspectiveCameras(
+ focal_length=focal,
+ principal_point=pp,
+ R=R, T=T,
+ device=R.device)
+ pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]
+ else:
+ pred_mask = []
+ # ! Slow but less GPU usage
+ for i in range(n_views_per_iter):
+ #! PyTorch3D silhouette loss
+ pred_mask.append(self.mesh_rasterizer(mesh, cameras=pose[i])[..., 3])
+ pred_mask = torch.cat(pred_mask, dim=0)
+
+ output = {
+ 'rgb': pred_rgb,
+ 'rgb_gt': rgb,
+ 'mask': pred_mask,
+ 'mask_gt': mask_gt,
+ 'vis_mask': visible_mask,
+ }
+
+ return output
\ No newline at end of file
diff --git a/src/network/__init__.py b/src/network/__init__.py
new file mode 100644
index 0000000..f6a0b61
--- /dev/null
+++ b/src/network/__init__.py
@@ -0,0 +1,8 @@
+from src.network import encoder, decoder
+
+encoder_dict = {
+ 'local_pool_pointnet': encoder.LocalPoolPointnet,
+}
+decoder_dict = {
+ 'simple_local': decoder.LocalDecoder,
+}
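+
+# These registries map the `model.encoder` / `model.decoder` strings in the
+# yaml configs (e.g. 'local_pool_pointnet', 'simple_local') to network classes;
+# see Encode2Points in src/model.py for how they are instantiated.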
\ No newline at end of file
diff --git a/src/network/decoder.py b/src/network/decoder.py
new file mode 100644
index 0000000..5fc7bd1
--- /dev/null
+++ b/src/network/decoder.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \
+    normalize_coordinate
+
+
+class LocalDecoder(nn.Module):
+ ''' Decoder.
+ Instead of conditioning on global features, on plane/volume local features.
+ Args:
+ dim (int): input dimension
+ c_dim (int): dimension of latent conditioned code c
+ hidden_size (int): hidden size of Decoder network
+        n_blocks (int): number of ResnetBlockFC blocks
+        leaky (bool): whether to use leaky ReLUs
+        sample_mode (str): sampling feature strategy, bilinear|nearest
+        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+ '''
+
+ def __init__(self, dim=3, c_dim=128, out_dim=3,
+ hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None):
+ super().__init__()
+ self.c_dim = c_dim
+ self.n_blocks = n_blocks
+
+ if c_dim != 0:
+ self.fc_c = nn.ModuleList([
+ nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
+ ])
+
+
+ self.fc_p = nn.Linear(dim, hidden_size)
+
+ self.blocks = nn.ModuleList([
+ ResnetBlockFC(hidden_size) for i in range(n_blocks)
+ ])
+
+ self.fc_out = nn.Linear(hidden_size, out_dim)
+
+ if not leaky:
+ self.actvn = F.relu
+ else:
+ self.actvn = lambda x: F.leaky_relu(x, 0.2)
+
+ self.sample_mode = sample_mode
+ self.padding = padding
+ self.map2local = map2local
+ self.out_dim = out_dim
+
+
+ def sample_plane_feature(self, p, c, plane='xz'):
+ xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
+ xy = xy[:, :, None].float()
+ vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
+ c = F.grid_sample(c, vgrid, padding_mode='border',
+ align_corners=True,
+ mode=self.sample_mode).squeeze(-1)
+ return c
+
+ def sample_grid_feature(self, p, c):
+ p_nor = normalize_3d_coordinate(p.clone())
+ p_nor = p_nor[:, :, None, None].float()
+ vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
+        # actually trilinear interpolation if mode = 'bilinear'
+ c = F.grid_sample(c, vgrid, padding_mode='border',
+ align_corners=True,
+ mode=self.sample_mode).squeeze(-1).squeeze(-1)
+ return c
+
+ def forward(self, p, c_plane, **kwargs):
+ batch_size = p.shape[0]
+ plane_type = list(c_plane.keys())
+ c = 0
+ if 'grid' in plane_type:
+ c += self.sample_grid_feature(p, c_plane['grid'])
+ if 'xz' in plane_type:
+ c += self.sample_plane_feature(p, c_plane['xz'], plane='xz')
+ if 'xy' in plane_type:
+ c += self.sample_plane_feature(p, c_plane['xy'], plane='xy')
+ if 'yz' in plane_type:
+ c += self.sample_plane_feature(p, c_plane['yz'], plane='yz')
+ c = c.transpose(1, 2)
+
+ p = p.float()
+
+ if self.map2local:
+ p = self.map2local(p)
+
+ net = self.fc_p(p)
+
+ for i in range(self.n_blocks):
+ if self.c_dim != 0:
+ net = net + self.fc_c[i](c)
+
+ net = self.blocks[i](net)
+
+ out = self.fc_out(self.actvn(net))
+
+
+ if self.out_dim > 3:
+ out = out.reshape(batch_size, -1, 3)
+
+ return out
\ No newline at end of file
diff --git a/src/network/encoder.py b/src/network/encoder.py
new file mode 100644
index 0000000..4385a9f
--- /dev/null
+++ b/src/network/encoder.py
@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch_scatter import scatter_mean, scatter_max
+from src.network.unet3d import UNet3D
+from src.network.unet import UNet
+from src.network.utils import get_embedder, normalize_3d_coordinate,\
+coordinate2index, ResnetBlockFC, normalize_coordinate
+
+class LocalPoolPointnet(nn.Module):
+ ''' PointNet-based encoder network with ResNet blocks for each point.
+ Number of input points are fixed.
+
+ Args:
+ c_dim (int): dimension of latent code c
+ dim (int): input points dimension
+ hidden_dim (int): hidden dimension of the network
+ scatter_type (str): feature aggregation when doing local pooling
+        unet (bool): whether to use a U-Net
+        unet_kwargs (dict): U-Net parameters
+        unet3d (bool): whether to use a 3D U-Net
+        unet3d_kwargs (dict): 3D U-Net parameters
+        plane_resolution (int): defined resolution for plane feature
+        grid_resolution (int): defined resolution for grid feature
+        plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
+        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+        n_blocks (int): number of ResnetBlockFC blocks
+        map2local (function): map global coordinates to local ones
+ pos_encoding (int): frequency for the positional encoding
+ '''
+
+ def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
+ unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
+ plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
+ map2local=None, pos_encoding=0):
+ super().__init__()
+
+ self.c_dim = c_dim
+
+ self.fc_pos = nn.Linear(dim, 2*hidden_dim)
+ self.blocks = nn.ModuleList([
+ ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
+ ])
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
+
+ self.actvn = nn.ReLU()
+ self.hidden_dim = hidden_dim
+
+ self.unet = None
+ if unet:
+ self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
+
+ self.unet3d = None
+ if unet3d:
+ self.unet3d = UNet3D(**unet3d_kwargs)
+
+ self.reso_plane = plane_resolution
+ self.reso_grid = grid_resolution
+ self.plane_type = plane_type
+ self.padding = padding
+
+ self.pe = None
+ if pos_encoding > 0:
+ embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim)
+ self.pe = embed_fn
+ self.fc_pos = nn.Linear(input_ch, 2*hidden_dim)
+
+ self.map2local = map2local
+
+
+ if scatter_type == 'max':
+ self.scatter = scatter_max
+ elif scatter_type == 'mean':
+ self.scatter = scatter_mean
+ else:
+ raise ValueError('incorrect scatter type')
+
+ def generate_plane_features(self, p, c, plane='xz'):
+ # acquire indices of features in plane
+ xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
+ index = coordinate2index(xy, self.reso_plane)
+
+ # scatter plane features from points
+ fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
+ c = c.permute(0, 2, 1) # B x 512 x T
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
+        fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse tensor (B x c_dim x reso x reso)
+
+ # process the plane features with UNet
+ if self.unet is not None:
+ fea_plane = self.unet(fea_plane)
+
+ return fea_plane
+
+ def generate_grid_features(self, p, c):
+ p_nor = normalize_3d_coordinate(p.clone())
+ index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
+ # scatter grid features from points
+ fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
+ c = c.permute(0, 2, 1)
+ fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
+        fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparse tensor (B x c_dim x reso x reso x reso)
+
+ if self.unet3d is not None:
+ fea_grid = self.unet3d(fea_grid)
+
+ return fea_grid
+
+ def pool_local(self, xy, index, c):
+ bs, fea_dim = c.size(0), c.size(2)
+ keys = xy.keys()
+
+ c_out = 0
+ for key in keys:
+ # scatter plane features from points
+ if key == 'grid':
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
+ else:
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
+ if self.scatter == scatter_max:
+ fea = fea[0]
+ # gather feature back to points
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
+ c_out += fea
+ return c_out.permute(0, 2, 1)
+
+
+ def forward(self, p, normalize=True):
+ batch_size, T, D = p.size()
+
+ # acquire the index for each point
+ coord = {}
+ index = {}
+ if 'xz' in self.plane_type:
+ coord['xz'] = normalize_coordinate(p.clone(), plane='xz')
+ index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
+ if 'xy' in self.plane_type:
+ coord['xy'] = normalize_coordinate(p.clone(), plane='xy')
+ index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
+ if 'yz' in self.plane_type:
+ coord['yz'] = normalize_coordinate(p.clone(), plane='yz')
+ index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
+ if 'grid' in self.plane_type:
+ if normalize:
+ coord['grid'] = normalize_3d_coordinate(p.clone())
+ else:
+ coord['grid'] = p.clone()[...,:3]
+ index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
+
+
+ if self.pe:
+ p = self.pe(p)
+
+ if self.map2local:
+ pp = self.map2local(p)
+ net = self.fc_pos(pp)
+ else:
+ net = self.fc_pos(p)
+
+ net = self.blocks[0](net)
+ for block in self.blocks[1:]:
+ pooled = self.pool_local(coord, index, net)
+ net = torch.cat([net, pooled], dim=2)
+ net = block(net)
+
+ c = self.fc_c(net)
+
+ fea = {}
+ if 'grid' in self.plane_type:
+ fea['grid'] = self.generate_grid_features(p, c)
+ if 'xz' in self.plane_type:
+ fea['xz'] = self.generate_plane_features(p, c, plane='xz')
+ if 'xy' in self.plane_type:
+ fea['xy'] = self.generate_plane_features(p, c, plane='xy')
+ if 'yz' in self.plane_type:
+ fea['yz'] = self.generate_plane_features(p, c, plane='yz')
+
+ return fea
\ No newline at end of file
diff --git a/src/network/net_rgb.py b/src/network/net_rgb.py
new file mode 100644
index 0000000..acdde01
--- /dev/null
+++ b/src/network/net_rgb.py
@@ -0,0 +1,234 @@
+# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py)
+import torch
+import torch.nn as nn
+import numpy as np
+from src.network.utils import get_embedder
+
+class RenderingNetwork(nn.Module):
+ def __init__(
+ self,
+ fea_size=0,
+ mode='naive',
+ d_out=3,
+ dims=[512, 512, 512, 512],
+ weight_norm=True,
+ pe_freq_view=0 # for positional encoding
+ ):
+ super().__init__()
+
+ self.mode = mode
+ if mode == 'naive':
+ d_in = 3
+ elif mode == 'no_feature':
+ d_in = 3 + 3 + 3
+ fea_size = 0
+ elif mode == 'full':
+ d_in = 3 + 3 + 3
+ else:
+ d_in = 3 + 3
+ dims = [d_in + fea_size] + dims + [d_out]
+
+ self.embedview_fn = None
+ if pe_freq_view > 0:
+ embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3)
+ self.embedview_fn = embedview_fn
+ dims[0] += (input_ch - 3)
+
+ self.num_layers = len(dims)
+
+ for l in range(0, self.num_layers - 1):
+ out_dim = dims[l + 1]
+ lin = nn.Linear(dims[l], out_dim)
+
+ if weight_norm:
+ lin = nn.utils.weight_norm(lin)
+
+ setattr(self, "lin" + str(l), lin)
+
+ self.relu = nn.ReLU()
+ self.tanh = nn.Tanh()
+
+ def forward(self, points, normals=None, view_dirs=None, feature_vectors=None):
+ if self.embedview_fn is not None:
+ view_dirs = self.embedview_fn(view_dirs)
+ # points = self.embedview_fn(points)
+
+        if (self.mode == 'full') and (feature_vectors is not None):
+            rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
+        elif (self.mode == 'no_feature') or ((self.mode == 'full') and (feature_vectors is None)):
+            rendering_input = torch.cat([points, view_dirs, normals], dim=-1)
+ elif self.mode == 'no_view_dir':
+ rendering_input = torch.cat([points, normals], dim=-1)
+ elif self.mode == 'no_normal':
+ rendering_input = torch.cat([points, view_dirs], dim=-1)
+ else:
+ rendering_input = points
+
+ x = rendering_input
+
+ for l in range(0, self.num_layers - 1):
+ lin = getattr(self, "lin" + str(l))
+
+ x = lin(x)
+
+ if l < self.num_layers - 2:
+ x = self.relu(x)
+
+ x = self.tanh(x)
+ return x
+
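+# A minimal usage sketch (shapes and argument values are assumptions, not part of
+# the original code):
+#   net = RenderingNetwork(fea_size=256, mode='full', pe_freq_view=4)
+#   rgb = net(points, normals=n, view_dirs=d, feature_vectors=fea)  # (..., 3) in [-1, 1]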
+
+class NeRFRenderingNetwork(nn.Module):
+ def __init__(
+ self,
+ feature_vector_size=0,
+ mode='naive',
+ d_in=3,
+ d_out=3,
+ dims=[512, 512, 512, 256],
+ weight_norm=True,
+ multires=0, # positional encoding of points
+ multires_view=0 # positional encoding of view
+ ):
+ super().__init__()
+
+ self.mode = mode
+ dims = [d_in + feature_vector_size] + dims
+
+
+ self.embed_fn = None
+ if multires > 0:
+ embed_fn, input_ch = get_embedder(multires, d_in=d_in)
+ self.embed_fn = embed_fn
+ dims[0] += (input_ch - 3)
+
+ self.num_layers = len(dims)
+
+ self.pts_net = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.num_layers - 1)])
+
+ self.embedview_fn = None
+ if multires_view > 0:
+ embedview_fn, view_ch = get_embedder(multires_view, d_in=3)
+ self.embedview_fn = embedview_fn
+ # dims[0] += (input_ch - 3)
+
+ if mode == 'full':
+ self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)])
+ self.rgb_net = nn.Linear(128, 3)
+ else:
+ self.rgb_net = nn.Linear(dims[-1], 3)
+
+
+ self.relu = nn.ReLU()
+ self.tanh = nn.Tanh()
+
+ def forward(self, points, normals=None, view_dirs=None, feature_vectors=None):
+ if self.embed_fn is not None:
+ points = self.embed_fn(points)
+ if self.embedview_fn is not None:
+ view_dirs = self.embedview_fn(view_dirs)
+
+ x = points
+ for net in self.pts_net:
+ x = net(x)
+ x = self.relu(x)
+
+ if self.mode=='full':
+ x = torch.cat([x, view_dirs], -1)
+ for net in self.view_net:
+ x = net(x)
+ x = self.relu(x)
+
+ x = self.rgb_net(x)
+ x = self.tanh(x)
+ return x
+
+class ImplicitNetwork(nn.Module):
+ def __init__(
+ self,
+ d_in,
+ d_out,
+ dims,
+ geometric_init=True,
+ feature_vector_size=0,
+ bias=1.0,
+ skip_in=(),
+ weight_norm=True,
+ multires=0
+ ):
+ super().__init__()
+
+ dims = [d_in] + dims + [d_out + feature_vector_size]
+
+ self.embed_fn = None
+ if multires > 0:
+ embed_fn, input_ch = get_embedder(multires)
+ self.embed_fn = embed_fn
+ dims[0] = input_ch
+
+ self.num_layers = len(dims)
+ self.skip_in = skip_in
+
+ for l in range(0, self.num_layers - 1):
+ if l + 1 in self.skip_in:
+ out_dim = dims[l + 1] - dims[0]
+ else:
+ out_dim = dims[l + 1]
+
+ lin = nn.Linear(dims[l], out_dim)
+
+ if geometric_init:
+ if l == self.num_layers - 2:
+ torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
+ torch.nn.init.constant_(lin.bias, -bias)
+ elif multires > 0 and l == 0:
+ torch.nn.init.constant_(lin.bias, 0.0)
+ torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
+ torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
+ elif multires > 0 and l in self.skip_in:
+ torch.nn.init.constant_(lin.bias, 0.0)
+ torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
+ torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
+ else:
+ torch.nn.init.constant_(lin.bias, 0.0)
+ torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
+
+ if weight_norm:
+ lin = nn.utils.weight_norm(lin)
+
+ setattr(self, "lin" + str(l), lin)
+
+ self.softplus = nn.Softplus(beta=100)
+
+ def forward(self, input, compute_grad=False):
+ if self.embed_fn is not None:
+ input = self.embed_fn(input)
+
+ x = input
+
+ for l in range(0, self.num_layers - 1):
+ lin = getattr(self, "lin" + str(l))
+
+ if l in self.skip_in:
+ x = torch.cat([x, input], 1) / np.sqrt(2)
+
+ x = lin(x)
+
+ if l < self.num_layers - 2:
+ x = self.softplus(x)
+
+ return x
+
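+    # A hedged usage sketch for `gradient` below (shapes are assumptions):
+    #   x = torch.rand(1024, 3)   # query points
+    #   g = net.gradient(x)       # -> (1024, 1, 3), d(output[:, :1])/dx via autograd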
+ def gradient(self, x):
+ x.requires_grad_(True)
+ y = self.forward(x)[:,:1]
+ d_output = torch.ones_like(y, requires_grad=False, device=y.device)
+ gradients = torch.autograd.grad(
+ outputs=y,
+ inputs=x,
+ grad_outputs=d_output,
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+ return gradients.unsqueeze(1)
\ No newline at end of file
diff --git a/src/network/unet.py b/src/network/unet.py
new file mode 100644
index 0000000..a58cc31
--- /dev/null
+++ b/src/network/unet.py
@@ -0,0 +1,256 @@
+'''
+Codes are from:
+https://github.com/jaxony/unet-pytorch/blob/master/model.py
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+import numpy as np
+
+def conv3x3(in_channels, out_channels, stride=1,
+ padding=1, bias=True, groups=1):
+ return nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ bias=bias,
+ groups=groups)
+
+def upconv2x2(in_channels, out_channels, mode='transpose'):
+ if mode == 'transpose':
+ return nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=2,
+ stride=2)
+ else:
+        # use a 1x1 conv after upsampling to map
+        # in_channels to out_channels
+ return nn.Sequential(
+ nn.Upsample(mode='bilinear', scale_factor=2),
+ conv1x1(in_channels, out_channels))
+
+def conv1x1(in_channels, out_channels, groups=1):
+ return nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ groups=groups,
+ stride=1)
+
+
+class DownConv(nn.Module):
+ """
+ A helper Module that performs 2 convolutions and 1 MaxPool.
+ A ReLU activation follows each convolution.
+ """
+ def __init__(self, in_channels, out_channels, pooling=True):
+ super(DownConv, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.pooling = pooling
+
+ self.conv1 = conv3x3(self.in_channels, self.out_channels)
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
+
+ if self.pooling:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ def forward(self, x):
+ x = F.relu(self.conv1(x))
+ x = F.relu(self.conv2(x))
+ before_pool = x
+ if self.pooling:
+ x = self.pool(x)
+ return x, before_pool
+
+
+class UpConv(nn.Module):
+ """
+ A helper Module that performs 2 convolutions and 1 UpConvolution.
+ A ReLU activation follows each convolution.
+ """
+ def __init__(self, in_channels, out_channels,
+ merge_mode='concat', up_mode='transpose'):
+ super(UpConv, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.merge_mode = merge_mode
+ self.up_mode = up_mode
+
+ self.upconv = upconv2x2(self.in_channels, self.out_channels,
+ mode=self.up_mode)
+
+ if self.merge_mode == 'concat':
+ self.conv1 = conv3x3(
+ 2*self.out_channels, self.out_channels)
+ else:
+ # num of input channels to conv2 is same
+ self.conv1 = conv3x3(self.out_channels, self.out_channels)
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
+
+
+ def forward(self, from_down, from_up):
+ """ Forward pass
+ Arguments:
+ from_down: tensor from the encoder pathway
+ from_up: upconv'd tensor from the decoder pathway
+ """
+ from_up = self.upconv(from_up)
+ if self.merge_mode == 'concat':
+ x = torch.cat((from_up, from_down), 1)
+ else:
+ x = from_up + from_down
+ x = F.relu(self.conv1(x))
+ x = F.relu(self.conv2(x))
+ return x
+
+
+class UNet(nn.Module):
+ """ `UNet` class is based on https://arxiv.org/abs/1505.04597
+
+ The U-Net is a convolutional encoder-decoder neural network.
+ Contextual spatial information (from the decoding,
+ expansive pathway) about an input tensor is merged with
+ information representing the localization of details
+ (from the encoding, compressive pathway).
+
+ Modifications to the original paper:
+ (1) padding is used in 3x3 convolutions to prevent loss
+ of border pixels
+ (2) merging outputs does not require cropping due to (1)
+ (3) residual connections can be used by specifying
+ UNet(merge_mode='add')
+    (4) if non-parametric upsampling is used in the decoder
+        pathway (specified by up_mode='upsample'), then an
+        additional 1x1 2d convolution occurs after upsampling
+        to reduce channel dimensionality by a factor of 2.
+        This channel halving happens with the convolution in
+        the transpose convolution (specified by up_mode='transpose')
+ """
+
+ def __init__(self, num_classes, in_channels=3, depth=5,
+ start_filts=64, up_mode='transpose',
+ merge_mode='concat', **kwargs):
+ """
+ Arguments:
+ in_channels: int, number of channels in the input tensor.
+ Default is 3 for RGB images.
+ depth: int, number of MaxPools in the U-Net.
+ start_filts: int, number of convolutional filters for the
+ first conv.
+ up_mode: string, type of upconvolution. Choices: 'transpose'
+ for transpose convolution or 'upsample' for nearest neighbour
+ upsampling.
+ """
+ super(UNet, self).__init__()
+
+ if up_mode in ('transpose', 'upsample'):
+ self.up_mode = up_mode
+ else:
+ raise ValueError("\"{}\" is not a valid mode for "
+ "upsampling. Only \"transpose\" and "
+ "\"upsample\" are allowed.".format(up_mode))
+
+ if merge_mode in ('concat', 'add'):
+ self.merge_mode = merge_mode
+ else:
+            raise ValueError("\"{}\" is not a valid mode for "
+                             "merging up and down paths. "
+                             "Only \"concat\" and "
+                             "\"add\" are allowed.".format(merge_mode))
+
+ # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
+ if self.up_mode == 'upsample' and self.merge_mode == 'add':
+ raise ValueError("up_mode \"upsample\" is incompatible "
+ "with merge_mode \"add\" at the moment "
+ "because it doesn't make sense to use "
+ "nearest neighbour to reduce "
+ "depth channels (by half).")
+
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.start_filts = start_filts
+ self.depth = depth
+
+ self.down_convs = []
+ self.up_convs = []
+
+ # create the encoder pathway and add to a list
+ for i in range(depth):
+ ins = self.in_channels if i == 0 else outs
+ outs = self.start_filts*(2**i)
+            pooling = i < depth - 1
+
+ down_conv = DownConv(ins, outs, pooling=pooling)
+ self.down_convs.append(down_conv)
+
+ # create the decoder pathway and add to a list
+ # - careful! decoding only requires depth-1 blocks
+ for i in range(depth-1):
+ ins = outs
+ outs = ins // 2
+ up_conv = UpConv(ins, outs, up_mode=up_mode,
+ merge_mode=merge_mode)
+ self.up_convs.append(up_conv)
+
+ # add the list of modules to current module
+ self.down_convs = nn.ModuleList(self.down_convs)
+ self.up_convs = nn.ModuleList(self.up_convs)
+
+ self.conv_final = conv1x1(outs, self.num_classes)
+
+ self.reset_params()
+
+ @staticmethod
+ def weight_init(m):
+ if isinstance(m, nn.Conv2d):
+ init.xavier_normal_(m.weight)
+ init.constant_(m.bias, 0)
+
+
+ def reset_params(self):
+ for i, m in enumerate(self.modules()):
+ self.weight_init(m)
+
+
+ def forward(self, x):
+ encoder_outs = []
+ # encoder pathway, save outputs for merging
+ for i, module in enumerate(self.down_convs):
+ x, before_pool = module(x)
+ encoder_outs.append(before_pool)
+ for i, module in enumerate(self.up_convs):
+ before_pool = encoder_outs[-(i+2)]
+ x = module(before_pool, x)
+
+        # No softmax is used, so you need to use
+        # nn.CrossEntropyLoss in your training script,
+        # as that loss applies the softmax internally.
+ x = self.conv_final(x)
+ return x
+
+if __name__ == "__main__":
+ """
+ testing
+ """
+ model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
+ print(model)
+ print(sum(p.numel() for p in model.parameters()))
+
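+    # (hedged note: the NaN probe below estimates the receptive field; the printed
+    # number is the fraction of output pixels reached by the single NaN input pixel)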
+ reso = 176
+ x = np.zeros((1, 1, reso, reso))
+ x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
+ x = torch.FloatTensor(x)
+
+ out = model(x)
+ print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
+
+ # loss = torch.sum(out)
+ # loss.backward()
diff --git a/src/network/unet3d.py b/src/network/unet3d.py
new file mode 100644
index 0000000..3c2bae2
--- /dev/null
+++ b/src/network/unet3d.py
@@ -0,0 +1,559 @@
+'''
+Code from the 3D UNet implementation:
+https://github.com/wolny/pytorch-3dunet/
+'''
+import importlib
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from functools import partial
+from src.network.utils import get_embedder
+
+def number_of_features_per_level(init_channel_number, num_levels):
+ return [init_channel_number * 2 ** k for k in range(num_levels)]
+
+
+def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
+ return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
+
+
+def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
+ """
+    Create a list of modules that together constitute a single conv layer with non-linearity
+ and optional batchnorm/groupnorm.
+
+ Args:
+ in_channels (int): number of input channels
+ out_channels (int): number of output channels
+ order (string): order of things, e.g.
+ 'cr' -> conv + ReLU
+ 'gcr' -> groupnorm + conv + ReLU
+ 'cl' -> conv + LeakyReLU
+ 'ce' -> conv + ELU
+ 'bcr' -> batchnorm + conv + ReLU
+ num_groups (int): number of groups for the GroupNorm
+ padding (int): add zero-padding to the input
+
+ Return:
+ list of tuple (name, module)
+ """
+ assert 'c' in order, "Conv layer MUST be present"
+ assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
+
+ modules = []
+ for i, char in enumerate(order):
+ if char == 'r':
+ modules.append(('ReLU', nn.ReLU(inplace=True)))
+ elif char == 'l':
+ modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
+ elif char == 'e':
+ modules.append(('ELU', nn.ELU(inplace=True)))
+ elif char == 'c':
+ # add learnable bias only in the absence of batchnorm/groupnorm
+ bias = not ('g' in order or 'b' in order)
+ modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
+ elif char == 'g':
+ is_before_conv = i < order.index('c')
+ if is_before_conv:
+ num_channels = in_channels
+ else:
+ num_channels = out_channels
+
+ # use only one group if the given number of groups is greater than the number of channels
+ if num_channels < num_groups:
+ num_groups = 1
+
+ assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
+ modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
+ elif char == 'b':
+ is_before_conv = i < order.index('c')
+ if is_before_conv:
+ modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
+ else:
+ modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
+ else:
+ raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
+
+ return modules
+
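+# Example (a hedged sketch): create_conv(8, 16, 3, order='gcr', num_groups=4) returns
+# [('groupnorm', GroupNorm(4, 8)), ('conv', Conv3d(8, 16, ...)), ('ReLU', ReLU())];
+# GroupNorm normalizes the 8 input channels because 'g' precedes 'c', and the conv
+# has bias=False since a norm layer is present.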
+
+class SingleConv(nn.Sequential):
+ """
+ Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
+ of operations can be specified via the `order` parameter
+
+ Args:
+ in_channels (int): number of input channels
+ out_channels (int): number of output channels
+ kernel_size (int): size of the convolving kernel
+ order (string): determines the order of layers, e.g.
+ 'cr' -> conv + ReLU
+ 'crg' -> conv + ReLU + groupnorm
+ 'cl' -> conv + LeakyReLU
+ 'ce' -> conv + ELU
+ num_groups (int): number of groups for the GroupNorm
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
+ super(SingleConv, self).__init__()
+
+ for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
+ self.add_module(name, module)
+
+
+class DoubleConv(nn.Sequential):
+ """
+ A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
+ We use (Conv3d+ReLU+GroupNorm3d) by default.
+ This can be changed however by providing the 'order' argument, e.g. in order
+ to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
+ Use padded convolutions to make sure that the output (H_out, W_out) is the same
+ as (H_in, W_in), so that you don't have to crop in the decoder path.
+
+ Args:
+ in_channels (int): number of input channels
+ out_channels (int): number of output channels
+ encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
+ kernel_size (int): size of the convolving kernel
+ order (string): determines the order of layers, e.g.
+ 'cr' -> conv + ReLU
+ 'crg' -> conv + ReLU + groupnorm
+ 'cl' -> conv + LeakyReLU
+ 'ce' -> conv + ELU
+ num_groups (int): number of groups for the GroupNorm
+ """
+
+ def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
+ super(DoubleConv, self).__init__()
+ if encoder:
+ # we're in the encoder path
+ conv1_in_channels = in_channels
+ conv1_out_channels = out_channels // 2
+ if conv1_out_channels < in_channels:
+ conv1_out_channels = in_channels
+ conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
+ else:
+ # we're in the decoder path, decrease the number of channels in the 1st convolution
+ conv1_in_channels, conv1_out_channels = in_channels, out_channels
+ conv2_in_channels, conv2_out_channels = out_channels, out_channels
+
+ # conv1
+ self.add_module('SingleConv1',
+ SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
+ # conv2
+ self.add_module('SingleConv2',
+ SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
+
+
+class ExtResNetBlock(nn.Module):
+ """
+ Basic UNet block consisting of a SingleConv followed by the residual block.
+ The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
+ of output channels is compatible with the residual block that follows.
+ This block can be used instead of standard DoubleConv in the Encoder module.
+ Motivated by: https://arxiv.org/pdf/1706.00120.pdf
+
+ Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
+ super(ExtResNetBlock, self).__init__()
+
+ # first convolution
+ self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
+ # residual block
+ self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
+ # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
+ n_order = order
+ for c in 'rel':
+ n_order = n_order.replace(c, '')
+ self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
+ num_groups=num_groups)
+
+ # create non-linearity separately
+ if 'l' in order:
+ self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ elif 'e' in order:
+ self.non_linearity = nn.ELU(inplace=True)
+ else:
+ self.non_linearity = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ # apply first convolution and save the output as a residual
+ out = self.conv1(x)
+ residual = out
+
+ # residual block
+ out = self.conv2(out)
+ out = self.conv3(out)
+
+ out += residual
+ out = self.non_linearity(out)
+
+ return out
+
+
+class Encoder(nn.Module):
+ """
+    A single module from the encoder path consisting of an optional max
+    pooling layer (one may specify a MaxPool kernel_size different from
+    the standard (2,2,2), e.g. if the volumetric data is anisotropic;
+    make sure to use a complementary scale_factor in the decoder path)
+    followed by a DoubleConv module.
+ Args:
+ in_channels (int): number of input channels
+ out_channels (int): number of output channels
+ conv_kernel_size (int): size of the convolving kernel
+ apply_pooling (bool): if True use MaxPool3d before DoubleConv
+ pool_kernel_size (tuple): the size of the window to take a max over
+ pool_type (str): pooling layer: 'max' or 'avg'
+ basic_module(nn.Module): either ResNetBlock or DoubleConv
+ conv_layer_order (string): determines the order of layers
+ in `DoubleConv` module. See `DoubleConv` for more info.
+ num_groups (int): number of groups for the GroupNorm
+ """
+
+ def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
+ pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg',
+ num_groups=8):
+ super(Encoder, self).__init__()
+ assert pool_type in ['max', 'avg']
+ if apply_pooling:
+ if pool_type == 'max':
+ self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
+ else:
+ self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
+ else:
+ self.pooling = None
+
+ self.basic_module = basic_module(in_channels, out_channels,
+ encoder=True,
+ kernel_size=conv_kernel_size,
+ order=conv_layer_order,
+ num_groups=num_groups)
+
+ def forward(self, x):
+ if self.pooling is not None:
+ x = self.pooling(x)
+ x = self.basic_module(x)
+ return x
+
+
+class Decoder(nn.Module):
+ """
+ A single module for decoder path consisting of the upsampling layer
+ (either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
+ Args:
+ in_channels (int): number of input channels
+ out_channels (int): number of output channels
+ kernel_size (int): size of the convolving kernel
+ scale_factor (tuple): used as the multiplier for the image H/W/D in
+ case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
+ from the corresponding encoder
+ basic_module(nn.Module): either ResNetBlock or DoubleConv
+ conv_layer_order (string): determines the order of layers
+ in `DoubleConv` module. See `DoubleConv` for more info.
+ num_groups (int): number of groups for the GroupNorm
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
+ conv_layer_order='crg', num_groups=8, mode='nearest'):
+ super(Decoder, self).__init__()
+ if basic_module == DoubleConv:
+ # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
+ self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
+ # concat joining
+ self.joining = partial(self._joining, concat=True)
+ else:
+ # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
+ self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
+ # sum joining
+ self.joining = partial(self._joining, concat=False)
+ # adapt the number of in_channels for the ExtResNetBlock
+ in_channels = out_channels
+
+ self.basic_module = basic_module(in_channels, out_channels,
+ encoder=False,
+ kernel_size=kernel_size,
+ order=conv_layer_order,
+ num_groups=num_groups)
+
+ def forward(self, encoder_features, x):
+ x = self.upsampling(encoder_features=encoder_features, x=x)
+ x = self.joining(encoder_features, x)
+ x = self.basic_module(x)
+ return x
+
+ @staticmethod
+ def _joining(encoder_features, x, concat):
+ if concat:
+ return torch.cat((encoder_features, x), dim=1)
+ else:
+ return encoder_features + x
+
+
+class Upsampling(nn.Module):
+ """
+    Upsamples given multi-channel 3D data using either interpolation or a learned transposed convolution.
+
+ Args:
+ transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
+ in_channels (int): number of input channels for transposed conv
+ out_channels (int): number of output channels for transpose conv
+ kernel_size (int or tuple): size of the convolving kernel
+ scale_factor (int or tuple): stride of the convolution
+ mode (str): algorithm used for upsampling:
+ 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
+ """
+
+ def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3,
+ scale_factor=(2, 2, 2), mode='nearest'):
+ super(Upsampling, self).__init__()
+
+ if transposed_conv:
+ # make sure that the output size reverses the MaxPool3d from the corresponding encoder
+ # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0])
+ self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
+ padding=1)
+ else:
+ self.upsample = partial(self._interpolate, mode=mode)
+
+ def forward(self, encoder_features, x):
+ output_size = encoder_features.size()[2:]
+ return self.upsample(x, output_size)
+
+ @staticmethod
+ def _interpolate(x, size, mode):
+ return F.interpolate(x, size=size, mode=mode)
+
+
+class FinalConv(nn.Sequential):
+ """
+    A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and a final 1x1 convolution
+    which reduces the number of channels to 'out_channels'.
+    We use (Conv3d+ReLU+GroupNorm3d) by default.
+    This can be changed, however, by providing the 'order' argument, e.g. in order
+ to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
+ Args:
+ in_channels (int): number of input channels
+ out_channels (int): number of output channels
+ kernel_size (int): size of the convolving kernel
+ order (string): determines the order of layers, e.g.
+ 'cr' -> conv + ReLU
+ 'crg' -> conv + ReLU + groupnorm
+ num_groups (int): number of groups for the GroupNorm
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
+ super(FinalConv, self).__init__()
+
+ # conv1
+ self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
+
+ # in the last layer a 1×1 convolution reduces the number of output channels to out_channels
+ final_conv = nn.Conv3d(in_channels, out_channels, 1)
+ self.add_module('final_conv', final_conv)
+
+class Abstract3DUNet(nn.Module):
+ """
+ Base class for standard and residual UNet.
+
+ Args:
+ in_channels (int): number of input channels
+ out_channels (int): number of output segmentation masks;
+            Note that out_channels might correspond either to
+            different semantic classes or to different binary segmentation masks.
+ It's up to the user of the class to interpret the out_channels and
+ use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
+ or BCEWithLogitsLoss (two-class) respectively)
+        f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer, the number
+            of feature maps at level k is given by the geometric progression f_maps * 2^k, k=0,...,num_levels-1
+ final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
+ final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used
+ to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model.
+ basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....)
+ layer_order (string): determines the order of layers
+ in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
+ See `SingleConv` for more info
+ num_groups (int): number of groups for the GroupNorm
+ num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
+ is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied
+ after the final convolution; if False (regression problem) the normalization layer is skipped at the end
+ testing (bool): if True (testing mode) the `final_activation` (if present, i.e. `is_segmentation=true`)
+ will be applied as the last operation during the forward pass; if False the model is in training mode
+ and the `final_activation` (even if present) won't be applied; default: False
+ """
+
+ def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
+ num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs):
+ super(Abstract3DUNet, self).__init__()
+
+ self.testing = testing
+
+ if isinstance(f_maps, int):
+ f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
+
+ # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
+
+ self.embed_fn = None
+ if pe_freq > 0:
+ embed_fn, input_ch = get_embedder(pe_freq, d_in=in_channels)
+ self.embed_fn = embed_fn
+ in_channels = input_ch
+
+ encoders = []
+ for i, out_feature_num in enumerate(f_maps):
+ if i == 0:
+ encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module,
+ conv_layer_order=layer_order, num_groups=num_groups)
+ else:
+ # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
+ # currently pools with a constant kernel: (2, 2, 2)
+ encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module,
+ conv_layer_order=layer_order, num_groups=num_groups)
+ encoders.append(encoder)
+
+ self.encoders = nn.ModuleList(encoders)
+
+ # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
+ decoders = []
+ reversed_f_maps = list(reversed(f_maps))
+ for i in range(len(reversed_f_maps) - 1):
+ if basic_module == DoubleConv:
+ in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
+ else:
+ in_feature_num = reversed_f_maps[i]
+
+ out_feature_num = reversed_f_maps[i + 1]
+ # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
+ # currently strides with a constant stride: (2, 2, 2)
+ decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module,
+ conv_layer_order=layer_order, num_groups=num_groups)
+ decoders.append(decoder)
+
+ self.decoders = nn.ModuleList(decoders)
+
+ # in the last layer a 1×1 convolution reduces the number of output
+ # channels to the number of labels
+ self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
+
+ if is_segmentation:
+ # semantic segmentation problem
+ if final_sigmoid:
+ self.final_activation = nn.Sigmoid()
+ else:
+ self.final_activation = nn.Softmax(dim=1)
+ else:
+ # regression problem
+ self.final_activation = None
+
+ def forward(self, x):
+
+ if self.embed_fn is not None:
+ x = self.embed_fn(x.permute(0, 2, 3, 4, 1))
+ x = x.permute(0, 4, 1, 2, 3)
+
+ # encoder part
+ encoders_features = []
+ for encoder in self.encoders:
+ x = encoder(x)
+ # reverse the encoder outputs to be aligned with the decoder
+ encoders_features.insert(0, x)
+
+ # remove the last encoder's output from the list
+ # !!remember: it's the 1st in the list
+ encoders_features = encoders_features[1:]
+
+ # decoder part
+ for decoder, encoder_features in zip(self.decoders, encoders_features):
+ # pass the output from the corresponding encoder and the output
+ # of the previous decoder
+ x = decoder(encoder_features, x)
+
+ x = self.final_conv(x)
+
+ # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs
+ # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
+ if self.testing and self.final_activation is not None:
+ x = self.final_activation(x)
+
+ return x
+
+
+class UNet3D(Abstract3DUNet):
+ """
+    3DUnet model from
+    `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
+    <https://arxiv.org/abs/1606.06650>`.
+
+ Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
+ """
+
+ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+ num_groups=8, num_levels=4, is_segmentation=True, **kwargs):
+ super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid,
+ basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order,
+ num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation,
+ **kwargs)
+
+
+class ResidualUNet3D(Abstract3DUNet):
+ """
+ Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
+ Uses ExtResNetBlock as a basic building block, summation joining instead
+ of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
+ Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
+ """
+
+ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+ num_groups=8, num_levels=5, is_segmentation=True, **kwargs):
+ super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels,
+ final_sigmoid=final_sigmoid,
+ basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order,
+ num_groups=num_groups, num_levels=num_levels,
+ is_segmentation=is_segmentation,
+ **kwargs)
+
+
+def get_model(config):
+ def _model_class(class_name):
+ m = importlib.import_module('pytorch3dunet.unet3d.model')
+ clazz = getattr(m, class_name)
+ return clazz
+
+ assert 'model' in config, 'Could not find model configuration'
+ model_config = config['model']
+ model_class = _model_class(model_config['name'])
+ return model_class(**model_config)
+
+
+if __name__ == "__main__":
+ """
+ testing
+ """
+ in_channels = 1
+ out_channels = 1
+ f_maps = 32
+ num_levels = 2
+ model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr')
+ print(model)
+ print('number of parameters: ', sum(p.numel() for p in model.parameters()))
+
+ reso = 18
+
+ import numpy as np
+ import torch
+ x = np.zeros((1, 1, reso, reso, reso))
+ x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan
+ x = torch.FloatTensor(x)
+
+ out = model(x)
+ print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso)))
+
\ No newline at end of file
diff --git a/src/network/utils.py b/src/network/utils.py
new file mode 100644
index 0000000..68ea744
--- /dev/null
+++ b/src/network/utils.py
@@ -0,0 +1,167 @@
+""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
+
+import torch
+import torch.nn as nn
+
+class Embedder:
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.create_embedding_fn()
+
+ def create_embedding_fn(self):
+ embed_fns = []
+ d = self.kwargs['input_dims']
+ out_dim = 0
+ if self.kwargs['include_input']:
+ embed_fns.append(lambda x: x)
+ out_dim += d
+
+ max_freq = self.kwargs['max_freq_log2']
+ N_freqs = self.kwargs['num_freqs']
+
+ if self.kwargs['log_sampling']:
+ freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
+ else:
+ freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
+
+ for freq in freq_bands:
+ for p_fn in self.kwargs['periodic_fns']:
+ embed_fns.append(lambda x, p_fn=p_fn,
+ freq=freq: p_fn(x * freq))
+ out_dim += d
+
+ self.embed_fns = embed_fns
+ self.out_dim = out_dim
+
+ def embed(self, inputs):
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
+
+def get_embedder(multires, d_in=3):
+ embed_kwargs = {
+ 'include_input': True,
+ 'input_dims': d_in,
+ 'max_freq_log2': multires-1,
+ 'num_freqs': multires,
+ 'log_sampling': True,
+ 'periodic_fns': [torch.sin, torch.cos],
+ }
+
+ embedder_obj = Embedder(**embed_kwargs)
+ def embed(x, eo=embedder_obj): return eo.embed(x)
+ return embed, embedder_obj.out_dim
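+
+# A minimal usage sketch (values are assumptions): with multires=4 and d_in=3 the
+# embedding is [x, sin(2^0 x), cos(2^0 x), ..., sin(2^3 x), cos(2^3 x)], so
+# out_dim = 3 + 3*2*4 = 27:
+#   embed_fn, out_dim = get_embedder(4, d_in=3)
+#   y = embed_fn(torch.rand(100, 3))  # -> (100, 27)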
+
+def normalize_coordinate(p, plane='xz'):
+    ''' Normalize coordinate to [0, 1] for unit cube experiments
+
+    Args:
+        p (tensor): point
+        plane (str): plane feature type, ['xz', 'xy', 'yz']
+    '''
+ if plane == 'xz':
+ xy = p[:, :, [0, 2]]
+ elif plane =='xy':
+ xy = p[:, :, [0, 1]]
+ else:
+ xy = p[:, :, [1, 2]]
+
+ xy_new = xy
+    # if there are outliers out of the range
+ if xy_new.max() >= 1:
+ xy_new[xy_new >= 1] = 1 - 10e-6
+ if xy_new.min() < 0:
+ xy_new[xy_new < 0] = 0.0
+ return xy_new
+
+
+def normalize_3d_coordinate(p):
+ ''' Normalize coordinate to [0, 1] for unit cube experiments.
+ '''
+ if p.max() >= 1:
+ p[p >= 1] = 1 - 10e-6
+ if p.min() < 0:
+ p[p < 0] = 0.0
+ return p
+
+def coordinate2index(x, reso, coord_type='2d'):
+    ''' Convert normalized coordinates in [0, 1] to flat cell indices
+    on a 2D plane or 3D grid of the given resolution
+
+ Args:
+ x (tensor): coordinate
+ reso (int): defined resolution
+ coord_type (str): coordinate type
+ '''
+ x = (x * reso).long()
+ if coord_type == '2d': # plane
+ index = x[:, :, 0] + reso * x[:, :, 1]
+ elif coord_type == '3d': # grid
+ index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
+ index = index[:, None, :]
+ return index
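+
+# Example (hedged): with reso=64, a normalized 2D point (0.5, 0.25) falls into cell
+# floor(0.5*64) + 64*floor(0.25*64) = 32 + 64*16 = 1056.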
+
+
+class map2local(object):
+    ''' Map global coordinates to local coordinates within each voxel
+
+ Args:
+ s (float): the defined voxel size
+ pos_encoding (str): method for the positional encoding, linear|sin_cos
+ '''
+ def __init__(self, s, pos_encoding='linear'):
+ super().__init__()
+ self.s = s
+ # self.pe = positional_encoding(basis_function=pos_encoding, local=True)
+
+ def __call__(self, p):
+        # p = torch.remainder(p, self.s) / self.s # always positive
+ p = (p % self.s) / self.s
+ p[p < 0] = 0.0
+ # p = torch.fmod(p, self.s) / self.s # same sign as input p!
+ # p = self.pe(p)
+ return p
+
+# Resnet Blocks
+class ResnetBlockFC(nn.Module):
+ ''' Fully connected ResNet Block class.
+
+ Args:
+ size_in (int): input dimension
+ size_out (int): output dimension
+ size_h (int): hidden dimension
+ '''
+
+ def __init__(self, size_in, size_out=None, size_h=None, siren=False):
+ super().__init__()
+ # Attributes
+ if size_out is None:
+ size_out = size_in
+
+ if size_h is None:
+ size_h = min(size_in, size_out)
+
+ self.size_in = size_in
+ self.size_h = size_h
+ self.size_out = size_out
+ # Submodules
+ self.fc_0 = nn.Linear(size_in, size_h)
+ self.fc_1 = nn.Linear(size_h, size_out)
+ self.actvn = nn.ReLU()
+
+ if size_in == size_out:
+ self.shortcut = None
+ else:
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
+ # Initialization
+ nn.init.zeros_(self.fc_1.weight)
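+        # (zero-initializing fc_1 makes the block start out as the identity/shortcut
+        # mapping, so the residual branch contributes nothing at initialization)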
+
+ def forward(self, x):
+ net = self.fc_0(self.actvn(x))
+ dx = self.fc_1(self.actvn(net))
+
+ if self.shortcut is not None:
+ x_s = self.shortcut(x)
+ else:
+ x_s = x
+
+ return x_s + dx
\ No newline at end of file
diff --git a/src/optimization.py b/src/optimization.py
new file mode 100644
index 0000000..12a0c1d
--- /dev/null
+++ b/src/optimization.py
@@ -0,0 +1,349 @@
+import time, os
+import numpy as np
+import torch
+from torch.nn import functional as F
+import trimesh
+
+from src.dpsr import DPSR
+from src.model import PSR2Mesh
+from src.utils import grid_interp, verts_on_largest_mesh,\
+ export_pointcloud, mc_from_psr, GaussianSmoothing
+from src.visualize import visualize_points_mesh, visualize_psr_grid, \
+ visualize_mesh_phong, render_rgb
+from torchvision.utils import save_image
+from torchvision.io import write_video
+from pytorch3d.loss import chamfer_distance
+import open3d as o3d
+
+class Trainer(object):
+ '''
+ Args:
+ cfg : config file
+ optimizer : pytorch optimizer object
+ device : pytorch device
+ '''
+
+ def __init__(self, cfg, optimizer, device=None):
+ self.optimizer = optimizer
+ self.device = device
+ self.cfg = cfg
+ self.psr2mesh = PSR2Mesh.apply
+ self.data_type = cfg['data']['data_type']
+
+ # initialize DPSR
+ self.dpsr = DPSR(res=(cfg['model']['grid_res'],
+ cfg['model']['grid_res'],
+ cfg['model']['grid_res']),
+ sig=cfg['model']['psr_sigma'])
+ if torch.cuda.device_count() > 1:
+            self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR
+ self.dpsr = self.dpsr.to(device)
+
+ def train_step(self, data, inputs, model, it):
+ ''' Performs a training step.
+
+ Args:
+ data (dict) : data dictionary
+ inputs (torch.tensor) : input point clouds
+ model (nn.Module or None): a neural network or None
+            it (int)                 : the current iteration number
+ '''
+
+ self.optimizer.zero_grad()
+ loss, loss_each = self.compute_loss(inputs, data, model, it)
+
+ loss.backward()
+ self.optimizer.step()
+
+ return loss.item(), loss_each
+
+ def compute_loss(self, inputs, data, model, it=0):
+ ''' Compute the loss.
+ Args:
+ data (dict) : data dictionary
+ inputs (torch.tensor) : input point clouds
+ model (nn.Module or None): a neural network or None
+            it (int)                 : the current iteration number
+ '''
+
+ device = self.device
+ res = self.cfg['model']['grid_res']
+
+ # source oriented point clouds to PSR grid
+ psr_grid, points, normals = self.pcl2psr(inputs)
+
+ # build mesh
+ v, f, n = self.psr2mesh(psr_grid)
+
+        # the output vertices are in [0, 1); rescale them to the full [0, 1] range
+        # (this is a hack for our DPSR solver)
+        v = v * res / (res-1)
+
+ points = points * 2. - 1.
+ v = v * 2. - 1. # within the range of (-1, 1)
+
+ loss = 0
+ loss_each = {}
+ # compute loss
+ if self.data_type == 'point':
+ if self.cfg['train']['w_chamfer'] > 0:
+ loss_ = self.cfg['train']['w_chamfer'] * \
+ self.compute_3d_loss(v, data)
+ loss_each['chamfer'] = loss_
+ loss += loss_
+ elif self.data_type == 'img':
+ loss, loss_each = self.compute_2d_loss(inputs, data, model)
+
+ return loss, loss_each
+
+
+ def pcl2psr(self, inputs):
+ ''' Convert an oriented point cloud to PSR indicator grid
+ Args:
+ inputs (torch.tensor): input oriented point clouds
+ '''
+
+ points, normals = inputs[...,:3], inputs[...,3:]
+ if self.cfg['model']['apply_sigmoid']:
+ points = torch.sigmoid(points)
+ if self.cfg['model']['normal_normalize']:
+ normals = normals / normals.norm(dim=-1, keepdim=True)
+
+ # DPSR to get grid
+ psr_grid = self.dpsr(points, normals).unsqueeze(1)
+ psr_grid = torch.tanh(psr_grid)
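+        # shape sketch (assumptions): inputs (B, N, 6) -> points/normals (B, N, 3)
+        # -> DPSR indicator grid (B, 1, res, res, res), squashed to (-1, 1) by tanh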
+
+ return psr_grid, points, normals
+
+ def compute_3d_loss(self, v, data):
+ ''' Compute the loss for point clouds.
+ Args:
+ v (torch.tensor) : mesh vertices
+ data (dict) : data dictionary
+ '''
+
+ pts_gt = data.get('target_points')
+        if self.cfg['train']['subsample_vertex']:
+            # chamfer distance only on randomly sampled vertices
+            idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point'])
+ loss, _ = chamfer_distance(v[:, idx], pts_gt)
+ else:
+ loss, _ = chamfer_distance(v, pts_gt)
+
+ return loss
+
+ def compute_2d_loss(self, inputs, data, model):
+ ''' Compute the 2D losses.
+ Args:
+ inputs (torch.tensor) : input source point clouds
+ data (dict) : data dictionary
+ model (nn.Module or None): neural network or None
+ '''
+
+ losses = {"color":
+ {"weight": self.cfg['train']['l_weight']['rgb'],
+ "values": []
+ },
+ "silhouette":
+ {"weight": self.cfg['train']['l_weight']['mask'],
+ "values": []},
+ }
+ loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses}
+
+ # forward pass
+ out = model(inputs, data)
+
+ if out['rgb'] is not None:
+ rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'],
+ -1, 3)[out['vis_mask']]
+ loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt,
+ out['rgb']) / out['rgb'].shape[0]
+
+ if out['mask'] is not None:
+ loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean()
+
+ # weighted sum of the losses
+ loss = torch.tensor(0.0, device=self.device)
+ for k, l in loss_all.items():
+ loss += l * losses[k]["weight"]
+ losses[k]["values"].append(l)
+
+ return loss, loss_all
+
+ def point_resampling(self, inputs):
+ ''' Resample points
+ Args:
+ inputs (torch.tensor): oriented point clouds
+ '''
+
+ psr_grid, points, normals = self.pcl2psr(inputs)
+
+ # shortcuts
+ n_grow = self.cfg['train']['n_grow_points']
+
+ # [hack] for points resampled from the mesh from marching cubes,
+ # we need to divide by s instead of (s-1), and the scale is correct.
+ verts, faces, _ = mc_from_psr(psr_grid, real_scale=False, zero_level=0)
+
+ # find the largest component
+ pts_mesh, faces_mesh = verts_on_largest_mesh(verts, faces)
+
+ # sample vertices only from the largest component, not from fragments
+ mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh)
+ pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True)
+ normals_i = mesh.face_normals[face_idx].astype('float32')
+ pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None]
+ n_mesh = torch.tensor(normals_i).to(self.device)[None]
+
+ points, normals = pts_mesh, n_mesh
+ print('{} total points are resampled'.format(points.shape[1]))
+
+ # update inputs
+ points = torch.log(points / (1 - points)) # inverse sigmoid
+ inputs = torch.cat([points, normals], dim=-1)
+ inputs.requires_grad = True
+
+ return inputs
+
+ def visualize(self, data, inputs, renderer, epoch, o3d_vis=None):
+ ''' Visualization.
+ Args:
+ data (dict) : data dictionary
+ inputs (torch.tensor) : source point clouds
+ renderer (nn.Module or None): a neural network or None
+            epoch (int)                 : the current epoch number
+ o3d_vis (o3d.Visualizer) : open3d visualizer
+ '''
+
+ data_type = self.cfg['data']['data_type']
+ it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every']))
+
+
+        if self.cfg['train']['exp_mesh'] \
+            or self.cfg['train']['exp_pcl'] \
+            or self.cfg['train']['o3d_show']:
+ psr_grid, points, normals = self.pcl2psr(inputs)
+
+ with torch.no_grad():
+ v, f, n = mc_from_psr(psr_grid, pytorchify=True,
+ zero_level=self.cfg['data']['zero_level'], real_scale=True)
+ v, f, n = v[None], f[None], n[None]
+
+ v = v * 2. - 1. # change to the range of [-1, 1]
+
+ color_v = None
+ if data_type == 'img':
+                if self.cfg['train']['vis_vert_color'] and \
+                    (self.cfg['train']['l_weight']['rgb'] != 0.):
+ color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy()
+ color_v[color_v<0], color_v[color_v>1] = 0., 1.
+
+ vv = v.detach().squeeze().cpu().numpy()
+ ff = f.detach().squeeze().cpu().numpy()
+ points = points * 2 - 1
+ visualize_points_mesh(o3d_vis, points, normals,
+ vv, ff, self.cfg, it, epoch, color_v=color_v)
+
+ else:
+ v, f, n = inputs
+
+
+        if (data_type == 'img') and (self.cfg['train']['vis_rendering']):
+ pred_imgs = []
+ pred_masks = []
+ n_views = len(data['poses'])
+ # idx_list = trange(n_views)
+ idx_list = [13, 24, 27, 48]
+
+ model = renderer.eval()
+ for idx in idx_list:
+ pose = data['poses'][idx]
+ rgb = data['rgbs'][idx]
+ mask_gt = data['masks'][idx]
+                img_size = rgb.shape[0] if rgb.shape[0] == rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
+ ray = None
+ if 'rays' in data.keys():
+ ray = data['rays'][idx]
+ if self.cfg['train']['l_weight']['rgb'] != 0.:
+ fea_grid = None
+ if model.unet3d is not None:
+ with torch.no_grad():
+ fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1)
+ if model.encoder is not None:
+ pp = torch.cat([(points+1)/2, normals], dim=-1)
+ fea_grid = model.encoder(pp,
+ normalize=False).permute(0, 2, 3, 4, 1)
+
+ pred, visible_mask = render_rgb(v, f, n, pose,
+ model.rendering_network.eval(),
+ img_size, ray=ray, fea_grid=fea_grid)
+ img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3])
+ img_pred[visible_mask] = pred.detach().cpu()
+
+ img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3)
+ img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1.
+ filename=os.path.join(self.cfg['train']['dir_rendering'],
+ 'rendering_{}_{:d}.png'.format(it, idx))
+ save_image(img_pred.permute(2, 0, 1), filename)
+ pred_imgs.append(img_pred)
+
+                # Mesh rendering using the Phong shading model
+ filename=os.path.join(self.cfg['train']['dir_rendering'],
+ 'mesh_{}_{:d}.png'.format(it, idx))
+ visualize_mesh_phong(v, f, n, pose, img_size, name=filename)
+
+ if len(pred_imgs) >= 1:
+ pred_imgs = torch.stack(pred_imgs, dim=0)
+ save_image(pred_imgs.permute(0, 3, 1, 2),
+ os.path.join(self.cfg['train']['dir_rendering'],
+ '{}.png'.format(it)), nrow=4)
+ if self.cfg['train']['save_video']:
+ write_video(os.path.join(self.cfg['train']['dir_rendering'],
+ '{}.mp4'.format(it)),
+ (pred_imgs*255.).type(torch.uint8), fps=24)
+
+ def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None):
+ ''' Save meshes and point clouds.
+ Args:
+ inputs (torch.tensor) : source point clouds
+            epoch (int)            : the current epoch number
+ center (numpy.array) : center of the shape
+ scale (numpy.array) : scale of the shape
+ '''
+
+ exp_pcl = self.cfg['train']['exp_pcl']
+ exp_mesh = self.cfg['train']['exp_mesh']
+
+ psr_grid, points, normals = self.pcl2psr(inputs)
+
+ if exp_pcl:
+ dir_pcl = self.cfg['train']['dir_pcl']
+ p = points.squeeze(0).detach().cpu().numpy()
+ p = p * 2 - 1
+ n = normals.squeeze(0).detach().cpu().numpy()
+ if scale is not None:
+ p *= scale
+ if center is not None:
+ p += center
+ export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n)
+ if exp_mesh:
+ dir_mesh = self.cfg['train']['dir_mesh']
+ with torch.no_grad():
+ v, f, _ = mc_from_psr(psr_grid,
+ zero_level=self.cfg['data']['zero_level'], real_scale=True)
+ v = v * 2 - 1
+ if scale is not None:
+ v *= scale
+ if center is not None:
+ v += center
+ mesh = o3d.geometry.TriangleMesh()
+ mesh.vertices = o3d.utility.Vector3dVector(v)
+ mesh.triangles = o3d.utility.Vector3iVector(f)
+ outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch))
+ o3d.io.write_triangle_mesh(outdir_mesh, mesh)
+
+ if self.cfg['train']['vis_psr']:
+ dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all'
+ visualize_psr_grid(psr_grid, out_dir=dir_psr_vis)
diff --git a/src/training.py b/src/training.py
new file mode 100644
index 0000000..b38a1d4
--- /dev/null
+++ b/src/training.py
@@ -0,0 +1,207 @@
+import os
+import numpy as np
+import torch
+from torch.nn import functional as F
+from collections import defaultdict
+import trimesh
+from tqdm import tqdm
+
+from src.dpsr import DPSR
+from src.utils import grid_interp, export_pointcloud, export_mesh, \
+ mc_from_psr, scale2onet, GaussianSmoothing
+from pytorch3d.ops.knn import knn_gather, knn_points
+from pytorch3d.loss import chamfer_distance
+
+class Trainer(object):
+    '''
+    Args:
+        cfg (dict): config dictionary
+        optimizer (optimizer): pytorch optimizer object
+        device (device): pytorch device
+    '''
+ def __init__(self, cfg, optimizer, device=None):
+ self.optimizer = optimizer
+ self.device = device
+ self.cfg = cfg
+ if self.cfg['train']['w_raw'] != 0:
+ from src.model import PSR2Mesh
+ self.psr2mesh = PSR2Mesh.apply
+
+ # initialize DPSR
+ self.dpsr = DPSR(res=(cfg['model']['grid_res'],
+ cfg['model']['grid_res'],
+ cfg['model']['grid_res']),
+ sig=cfg['model']['psr_sigma'])
+ if torch.cuda.device_count() > 1:
+            self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR
+ self.dpsr = self.dpsr.to(device)
+
+ if cfg['train']['gauss_weight']>0.:
+ self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device)
+
+ def train_step(self, inputs, data, model):
+ ''' Performs a training step.
+
+        Args:
+            inputs (torch.tensor): input point clouds (unused; points are read from `data`)
+            data (dict): data dictionary
+            model (nn.Module): network mapping input points to oriented output points
+        '''
+ self.optimizer.zero_grad()
+ p = data.get('inputs').to(self.device)
+
+ out = model(p)
+
+ points, normals = out
+
+ loss = 0
+ loss_each = {}
+ if self.cfg['train']['w_psr'] != 0:
+ psr_gt = data.get('gt_psr').to(self.device)
+ if self.cfg['model']['psr_tanh']:
+ psr_gt = torch.tanh(psr_gt)
+
+ psr_grid = self.dpsr(points, normals)
+ if self.cfg['model']['psr_tanh']:
+ psr_grid = torch.tanh(psr_grid)
+
+ # apply a rescaling weight based on GT SDF values
+ if self.cfg['train']['gauss_weight']>0:
+ gauss_sigma = self.cfg['train']['gauss_weight']
+ # set up the weighting for loss, higher weights
+ # for points near to the surface
+ psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1)
+ delta_x = delta_y = delta_z = 1
+ grad_x = (psr_gt_pad[:, 2:, :, :] - psr_gt_pad[:, :-2, :, :]) / 2 / delta_x
+ grad_y = (psr_gt_pad[:, :, 2:, :] - psr_gt_pad[:, :, :-2, :]) / 2 / delta_y
+ grad_z = (psr_gt_pad[:, :, :, 2:] - psr_gt_pad[:, :, :, :-2]) / 2 / delta_z
+ grad_x = grad_x[:, :, 1:-1, 1:-1]
+ grad_y = grad_y[:, 1:-1, :, 1:-1]
+ grad_z = grad_z[:, 1:-1, 1:-1, :]
+ psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1)
+ psr_grad_norm = psr_grad.norm(dim=-1)[:, None]
+ w = torch.nn.ReplicationPad3d(3)(psr_grad_norm)
+ w = 2*self.gauss_smooth(w).squeeze(1)
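+                # (the PSR gradient norm, from the central differences above, peaks
+                # near the zero level set, so the Gaussian-smoothed weight w
+                # up-weights the MSE close to the surface)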
+ loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt)
+ else:
+ loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt)
+
+ loss += loss_each['psr']
+
+ # regularization on the input point positions via chamfer distance
+ if self.cfg['train']['w_reg_point'] != 0.:
+ points_gt = data.get('gt_points').to(self.device)
+            loss_reg, _ = chamfer_distance(points, points_gt)
+
+ loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg
+ loss += loss_each['reg']
+
+ if self.cfg['train']['w_normals'] != 0.:
+ points_gt = data.get('gt_points').to(self.device)
+ normals_gt = data.get('gt_points.normals').to(self.device)
+ x_nn = knn_points(points, points_gt, K=1)
+ x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :]
+
+ cham_norm_x = F.l1_loss(normals, x_normals_near)
+ loss_norm = cham_norm_x
+
+ loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm
+ loss += loss_each['normals']
+
+ if self.cfg['train']['w_raw'] != 0:
+ res = self.cfg['model']['grid_res']
+ # DPSR to get grid
+ psr_grid = self.dpsr(points, normals)
+ if self.cfg['model']['psr_tanh']:
+ psr_grid = torch.tanh(psr_grid)
+
+ v, f, n = self.psr2mesh(psr_grid)
+
+ pts_gt = data.get('gt_points').to(self.device)
+
+ loss, _ = chamfer_distance(v, pts_gt)
+
+ loss.backward()
+ self.optimizer.step()
+
+ return loss.item(), loss_each
+
+ def save(self, model, data, epoch, id):
+
+ p = data.get('inputs').to(self.device)
+
+ exp_pcl = self.cfg['train']['exp_pcl']
+ exp_mesh = self.cfg['train']['exp_mesh']
+ exp_gt = self.cfg['generation']['exp_gt']
+ exp_input = self.cfg['generation']['exp_input']
+
+ model.eval()
+ with torch.no_grad():
+ points, normals = model(p)
+
+ if exp_gt:
+ points_gt = data.get('gt_points').to(self.device)
+ normals_gt = data.get('gt_points.normals').to(self.device)
+
+ if exp_pcl:
+ dir_pcl = self.cfg['train']['dir_pcl']
+ export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals)
+ if exp_gt:
+ export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt)
+ if exp_input:
+ export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p))
+
+ if exp_mesh:
+ dir_mesh = self.cfg['train']['dir_mesh']
+ psr_grid = self.dpsr(points, normals)
+ # psr_grid = torch.tanh(psr_grid)
+ with torch.no_grad():
+ v, f, _ = mc_from_psr(psr_grid,
+ zero_level=self.cfg['data']['zero_level'])
+ outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id))
+ export_mesh(outdir_mesh, scale2onet(v), f)
+ if exp_gt:
+ psr_gt = self.dpsr(points_gt, normals_gt)
+ with torch.no_grad():
+ v, f, _ = mc_from_psr(psr_gt,
+ zero_level=self.cfg['data']['zero_level'])
+ export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f)
+
+ def evaluate(self, val_loader, model):
+ ''' Performs an evaluation.
+ Args:
+ val_loader (dataloader): pytorch dataloader
+ '''
+ eval_list = defaultdict(list)
+
+ for data in tqdm(val_loader):
+ eval_step_dict = self.eval_step(data, model)
+
+ for k, v in eval_step_dict.items():
+ eval_list[k].append(v)
+
+ eval_dict = {k: np.mean(v) for k, v in eval_list.items()}
+ return eval_dict
+
+ def eval_step(self, data, model):
+ ''' Performs an evaluation step.
+ Args:
+ data (dict): data dictionary
+ '''
+ model.eval()
+ eval_dict = {}
+
+ p = data.get('inputs').to(self.device)
+ psr_gt = data.get('gt_psr').to(self.device)
+
+ with torch.no_grad():
+ # forward pass
+ points, normals = model(p)
+ # DPSR to get predicted psr grid
+ psr_grid = self.dpsr(points, normals)
+
+ eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item()
+ eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item()
+
+ return eval_dict
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..b9ce35b
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,645 @@
+import torch
+import io, os, logging, urllib
+import yaml
+import trimesh
+import imageio
+import numbers
+import math
+import numpy as np
+from collections import OrderedDict
+from plyfile import PlyData
+from torch import nn
+from torch.nn import functional as F
+from torch.utils import model_zoo
+from skimage import measure, img_as_float32
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
+from igl import adjacency_matrix, connected_components
+import open3d as o3d
+
+##################################################
+# Below are functions for DPSR
+
+def fftfreqs(res, dtype=torch.float32, exact=True):
+ """
+ Helper function to return frequency tensors
+ :param res: n_dims int tuple of number of frequency modes
+ :return:
+ """
+
+ n_dims = len(res)
+ freqs = []
+ for dim in range(n_dims - 1):
+ r_ = res[dim]
+ freq = np.fft.fftfreq(r_, d=1/r_)
+ freqs.append(torch.tensor(freq, dtype=dtype))
+ r_ = res[-1]
+ if exact:
+ freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
+ else:
+ freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
+ omega = torch.meshgrid(freqs)
+ omega = list(omega)
+ omega = torch.stack(omega, dim=-1)
+
+ return omega
+
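+# A shape-only sketch of fftfreqs: for a 32^3 grid the returned frequency
+# tensor has shape (32, 32, 17, 3), since the last axis uses the one-sided
+# rfft layout with res//2 + 1 modes:
+#   omega = fftfreqs((32, 32, 32))  # -> torch.Size([32, 32, 17, 3])
+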
+def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
+ """
+ multiply tensor x by i ** deg
+ """
+ deg %= 4
+ if deg == 0:
+ res = x
+ elif deg == 1:
+ res = x[..., [1, 0]]
+ res[..., 0] = -res[..., 0]
+ elif deg == 2:
+ res = -x
+ elif deg == 3:
+ res = x[..., [1, 0]]
+ res[..., 1] = -res[..., 1]
+ return res
+
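+# Example: with the last dim holding (real, imag), img realizes multiplication
+# by i, e.g. (2 + 3i) * i = -3 + 2i:
+#   img(torch.tensor([2., 3.]))  # -> tensor([-3., 2.])
+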
+def spec_gaussian_filter(res, sig):
+ omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
+ dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
+ filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
+ filter_.requires_grad = False
+
+ return filter_
+
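+# Sketch: builds an isotropic Gaussian low-pass filter over the rfft frequency
+# grid; e.g. spec_gaussian_filter((32, 32, 32), sig=2) has shape
+# (32, 32, 17, 1, 1), the two trailing singleton axes left for broadcasting.
+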
+def grid_interp(grid, pts, batched=True):
+ """
+ :param grid: tensor of shape (batch, *size, in_features)
+ :param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
+ :return values at query points
+ """
+ if not batched:
+ grid = grid.unsqueeze(0)
+ pts = pts.unsqueeze(0)
+ dim = pts.shape[-1]
+ bs = grid.shape[0]
+ size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
+ cubesize = 1.0 / size
+
+ ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
+ ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
+ ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
+ tmp = torch.tensor([0,1],dtype=torch.long)
+ com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
+ dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
+ ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
+ ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
+ ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
+ # latent code on neighbor nodes
+ if dim == 2:
+ lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
+ else:
+ lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features)
+
+ # weights of neighboring nodes
+ xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
+ xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
+ xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
+ pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
+ pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
+ pos_ = pos_.type(pts.dtype)
+ dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
+ weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
+ query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features)
+ if not batched:
+ query_values = query_values.squeeze(0)
+
+ return query_values
+
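+# Usage sketch: trilinear interpolation of a feature grid at query points in
+# [0, 1), with periodic wrap-around at the boundary:
+#   grid: (1, 32, 32, 32, 8), pts: (1, 100, 3)  ->  values: (1, 100, 8)
+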
+def scatter_to_grid(inds, vals, size):
+ """
+ Scatter update values into empty tensor of size size.
+ :param inds: (#values, dims)
+ :param vals: (#values)
+ :param size: tuple for size. len(size)=dims
+ """
+ dims = inds.shape[1]
+ assert(inds.shape[0] == vals.shape[0])
+ assert(len(size) == dims)
+ dev = vals.device
+ # result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
+ # # flatten inds
+ result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
+ # flatten inds
+ fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
+ fac = torch.tensor(fac, device=dev).type(inds.dtype)
+ inds_fold = torch.sum(inds*fac, dim=-1) # [#values,]
+ result.scatter_add_(0, inds_fold, vals)
+ result = result.view(*size)
+ return result
+
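+# Sketch: indices are flattened row-major and accumulated with scatter_add_,
+# so duplicate indices sum, e.g.
+#   inds = torch.tensor([[0, 0, 0], [0, 0, 0]]); vals = torch.ones(2)
+#   scatter_to_grid(inds, vals, (2, 2, 2))[0, 0, 0]  # -> 2.0
+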
+def point_rasterize(pts, vals, size):
+ """
+ :param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
+ :param vals: point values, tensor of shape (batch, num_points, features)
+ :param size: len(size)=dim tuple for grid size
+ :return rasterized values (batch, features, res0, res1, res2)
+ """
+ dim = pts.shape[-1]
+ assert(pts.shape[:2] == vals.shape[:2])
+ assert(pts.shape[2] == dim)
+ size_list = list(size)
+ size = torch.tensor(size).to(pts.device).float()
+ cubesize = 1.0 / size
+ bs = pts.shape[0]
+ nf = vals.shape[-1]
+ npts = pts.shape[1]
+ dev = pts.device
+
+ ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
+ ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
+ ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
+ tmp = torch.tensor([0,1],dtype=torch.long)
+ com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
+ dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
+ ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
+ ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
+ # ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
+ ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
+
+ # weights of neighboring nodes
+ xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
+ xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
+ xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
+ pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
+ pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
+ pos_ = pos_.type(pts.dtype)
+ dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
+ weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
+
+ ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1)
+ ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim)
+ ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
+ # ind_f = torch.arange(nf).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
+
+ ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
+ ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
+ ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
+ inds = torch.cat([ind_b, ind_f, ind_n], dim=-1) # (batch, num_points, 2**dim, nf, 1+1+dim)
+
+ # weighted values
+ vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
+
+ inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
+ vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
+ tensor_size = [bs, nf] + size_list
+ raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list)
+
+ return raster # [batch, nf, res, res, res]
+
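+# Usage sketch: the splatting counterpart of grid_interp -- trilinearly splat
+# point values onto a periodic grid:
+#   pts: (1, 100, 3) in [0, 1), vals: (1, 100, 4), size: (32, 32, 32)
+#   point_rasterize(pts, vals, (32, 32, 32))  # -> (1, 4, 32, 32, 32)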
+
+
+##################################################
+# Below are general utility functions
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.n = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.n = n
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ @property
+ def valcavg(self):
+ return self.val.sum().item() / (self.n != 0).sum().item()
+
+ @property
+ def avgcavg(self):
+ return self.avg.sum().item() / (self.count != 0).sum().item()
+
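+# Usage sketch:
+#   meter = AverageMeter()
+#   meter.update(batch_loss, n=batch_size)  # meter.avg then holds the running mean
+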
+def load_model_manual(state_dict, model):
+ new_state_dict = OrderedDict()
+ is_model_parallel = isinstance(model, torch.nn.DataParallel)
+ for k, v in state_dict.items():
+ if k.startswith('module.') != is_model_parallel:
+ if k.startswith('module.'):
+ # remove module
+ k = k[7:]
+ else:
+ # add module
+ k = 'module.' + k
+
+ new_state_dict[k]=v
+
+ model.load_state_dict(new_state_dict)
+
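+# Sketch: reconciles the 'module.' prefix added by torch.nn.DataParallel, so a
+# checkpoint saved with or without DataParallel loads into either variant:
+#   load_model_manual(torch.load('model.pt')['state_dict'], model)
+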
+def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
+ '''
+ Run marching cubes from PSR grid
+ '''
+ batch_size = psr_grid.shape[0]
+ s = psr_grid.shape[-1] # size of psr_grid
+ psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
+
+ if batch_size>1:
+ verts, faces, normals = [], [], []
+ for i in range(batch_size):
+            verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=zero_level)
+ verts.append(verts_cur)
+ faces.append(faces_cur)
+ normals.append(normals_cur)
+ verts = np.stack(verts, axis = 0)
+ faces = np.stack(faces, axis = 0)
+ normals = np.stack(normals, axis = 0)
+ else:
+        try:
+            verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
+        except ValueError:
+            # fall back to the default level when zero_level lies outside the grid's value range
+            verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
+ if real_scale:
+ verts = verts / (s-1) # scale to range [0, 1]
+ else:
+ verts = verts / s # scale to range [0, 1)
+
+ if pytorchify:
+ device = psr_grid.device
+ verts = torch.Tensor(np.ascontiguousarray(verts)).to(device)
+ faces = torch.Tensor(np.ascontiguousarray(faces)).to(device)
+ normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)
+
+ return verts, faces, normals
+
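+# Usage sketch: extract a mesh from a (1, res, res, res) indicator grid; with
+# pytorchify=True the outputs are tensors on the grid's device:
+#   v, f, n = mc_from_psr(psr_grid, pytorchify=True, zero_level=0)
+#   # v lies in [0, 1)^3; n holds the (flipped) marching-cubes normals
+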
+def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
+ verts = verts.squeeze()
+ faces = faces.squeeze()
+ pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size)
+ if mask_gt is not None:
+ #! only evaluate within the intersection
+ mask = mask & mask_gt
+    # find 3D points intersected on the mesh
+    if True: # interpolate via barycentric coordinates of the rasterized faces
+        w_masked = w[mask]
+        f_p = faces[pix_to_face[mask]].long() # corresponding faces for each pixel
+ # corresponding vertices for p_closest
+ v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
+
+ # calculate the intersection point of each pixel and the mesh
+ p_inters = w_masked[..., 0, None] * v_a + \
+ w_masked[..., 1, None] * v_b + \
+ w_masked[..., 2, None] * v_c
+    else:
+        # NOTE: disabled alternative; backprojects ndc to world coordinates using
+        # the z-buffer, but references uv/zbuf, which are not defined in this scope
+ W, H = img_size[1], img_size[0]
+ xy = uv.to(mask.device)[mask]
+ x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
+ y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
+ z = zbuf.squeeze().reshape(H * W)[mask]
+ xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
+
+ p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
+
+    # remove outlier intersection points that fall outside the unit cube
+    if (p_inters.max()>1) | (p_inters.min()<-1):
+        mask_bound = (p_inters>=-1) & (p_inters<=1)
+        mask_bound = (mask_bound.sum(dim=-1)==3)
+        mask[mask==True] = mask_bound
+        p_inters = p_inters[mask_bound]
+        print('Warning: found and removed outlier intersection points')
+
+ return p_inters, mask, f_p, w_masked
+
+def mesh_rasterization(verts, faces, pose, img_size):
+ '''
+ Use PyTorch3D to rasterize the mesh given a camera
+ '''
+ transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
+ if isinstance(pose, PerspectiveCameras):
+ transformed_v[..., 2] = 1/transformed_v[..., 2]
+ # find p_closest on mesh of each pixel via rasterization
+ transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
+ transformed_mesh,
+ image_size=img_size,
+ blur_radius=0,
+ faces_per_pixel=1,
+ perspective_correct=False
+ )
+ pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
+ mask = pix_to_face.clone() != -1
+ mask = mask.squeeze()
+ pix_to_face = pix_to_face.squeeze()
+ w = bary_coords.reshape(-1, 3)
+
+ return pix_to_face, w, mask
+
+def verts_on_largest_mesh(verts, faces):
+ '''
+    verts: Numpy array or torch.Tensor (N, 3)
+    faces: Numpy array or torch.Tensor (M, 3)
+ '''
+ if torch.is_tensor(faces):
+ verts = verts.squeeze().detach().cpu().numpy()
+ faces = faces.squeeze().int().detach().cpu().numpy()
+
+ A = adjacency_matrix(faces)
+ num, conn_idx, conn_size = connected_components(A)
+    if num == 0:
+        v_large, f_large = verts, faces
+    else:
+        max_idx = conn_size.argmax() # find the index of the largest component
+        # keep only the largest connected component; the trimesh-based split
+        # supersedes the igl-based vertex selection
+        mesh_largest = trimesh.Trimesh(verts, faces)
+        connected_comp = mesh_largest.split(only_watertight=False)
+        mesh_largest = connected_comp[max_idx]
+        v_large, f_large = mesh_largest.vertices, mesh_largest.faces
+        v_large = v_large.astype(np.float32)
+ return v_large, f_large
+
+def load_pointcloud(in_file):
+ plydata = PlyData.read(in_file)
+ vertices = np.stack([
+ plydata['vertex']['x'],
+ plydata['vertex']['y'],
+ plydata['vertex']['z']
+ ], axis=1)
+ return vertices
+
+# General config
+def load_config(path, default_path=None):
+ ''' Loads config file.
+
+ Args:
+ path (str): path to config file
+        default_path (str): path to the default config file
+ '''
+ # Load configuration from file itself
+ with open(path, 'r') as f:
+ cfg_special = yaml.load(f, Loader=yaml.Loader)
+
+ # Check if we should inherit from a config
+ inherit_from = cfg_special.get('inherit_from')
+
+ # If yes, load this config first as default
+ # If no, use the default_path
+ if inherit_from is not None:
+ cfg = load_config(inherit_from, default_path)
+ elif default_path is not None:
+ with open(default_path, 'r') as f:
+ cfg = yaml.load(f, Loader=yaml.Loader)
+ else:
+ cfg = dict()
+
+ # Include main configuration
+ update_recursive(cfg, cfg_special)
+
+ return cfg
+
+def update_config(config, unknown):
+ # update config given args
+ for idx,arg in enumerate(unknown):
+ if arg.startswith("--"):
+ keys = arg.replace("--","").split(':')
+ assert(len(keys)==2)
+ k1, k2 = keys
+ argtype = type(config[k1][k2])
+ if argtype == bool:
+ v = unknown[idx+1].lower() == 'true'
+ else:
+ if config[k1][k2] is not None:
+ v = type(config[k1][k2])(unknown[idx+1])
+ else:
+ v = unknown[idx+1]
+ print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
+ config[k1][k2] = v
+ return config
+
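+# Usage sketch (assuming `unknown` comes from parser.parse_known_args()):
+# overrides follow a --section:key value convention, e.g.
+#   python script.py config.yaml --train:batch_size 8 --model:psr_tanh True
+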
+def initialize_logger(cfg):
+ out_dir = cfg['train']['out_dir']
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+
+ cfg['train']['dir_model'] = os.path.join(out_dir, 'model')
+ os.makedirs(cfg['train']['dir_model'], exist_ok=True)
+
+ if cfg['train']['exp_mesh']:
+ cfg['train']['dir_mesh'] = os.path.join(out_dir, 'vis/mesh')
+ os.makedirs(cfg['train']['dir_mesh'], exist_ok=True)
+ if cfg['train']['exp_pcl']:
+ cfg['train']['dir_pcl'] = os.path.join(out_dir, 'vis/pointcloud')
+ os.makedirs(cfg['train']['dir_pcl'], exist_ok=True)
+ if cfg['train']['vis_rendering']:
+ cfg['train']['dir_rendering'] = os.path.join(out_dir, 'vis/rendering')
+ os.makedirs(cfg['train']['dir_rendering'], exist_ok=True)
+ if cfg['train']['o3d_show']:
+ cfg['train']['dir_o3d'] = os.path.join(out_dir, 'vis/o3d')
+ os.makedirs(cfg['train']['dir_o3d'], exist_ok=True)
+
+
+ logger = logging.getLogger("train")
+ logger.setLevel(logging.DEBUG)
+ logger.handlers = []
+ # ch = logging.StreamHandler()
+ # logger.addHandler(ch)
+ fh = logging.FileHandler(os.path.join(cfg['train']['out_dir'], "log.txt"))
+ logger.addHandler(fh)
+    logger.info('Output dir: %s', out_dir)
+ return logger
+
+def update_recursive(dict1, dict2):
+    ''' Update dict1 with the entries of dict2, recursively.
+
+ Args:
+ dict1 (dict): first dictionary to be updated
+ dict2 (dict): second dictionary which entries should be used
+
+ '''
+ for k, v in dict2.items():
+ if k not in dict1:
+ dict1[k] = dict()
+ if isinstance(v, dict):
+ update_recursive(dict1[k], v)
+ else:
+ dict1[k] = v
+
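+# e.g. update_recursive({'a': {'b': 1}}, {'a': {'c': 2}}) mutates the first
+# dict in place to {'a': {'b': 1, 'c': 2}}
+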
+def export_pointcloud(name, points, normals=None):
+ if len(points.shape) > 2:
+ points = points[0]
+ if normals is not None:
+ normals = normals[0]
+ if isinstance(points, torch.Tensor):
+ points = points.detach().cpu().numpy()
+ if normals is not None:
+ normals = normals.detach().cpu().numpy()
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ if normals is not None:
+ pcd.normals = o3d.utility.Vector3dVector(normals)
+ o3d.io.write_point_cloud(name, pcd)
+
+def export_mesh(name, v, f):
+ if len(v.shape) > 2:
+ v, f = v[0], f[0]
+ if isinstance(v, torch.Tensor):
+ v = v.detach().cpu().numpy()
+ f = f.detach().cpu().numpy()
+ mesh = o3d.geometry.TriangleMesh()
+ mesh.vertices = o3d.utility.Vector3dVector(v)
+ mesh.triangles = o3d.utility.Vector3iVector(f)
+ o3d.io.write_triangle_mesh(name, mesh)
+
+def scale2onet(p, scale=1.2):
+ '''
+ Scale the point cloud from SAP to ONet range
+ '''
+ return (p - 0.5) * scale
+
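+# e.g. with the default scale=1.2, SAP's [0, 1) cube maps to [-0.6, 0.6)
+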
+def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
+ if model is not None:
+ if schedule is not None:
+ optimizer = torch.optim.Adam([
+ {"params": model.parameters(),
+ "lr": schedule[0].get_learning_rate(epoch)},
+ {"params": inputs,
+ "lr": schedule[1].get_learning_rate(epoch)}])
+ elif 'lr' in cfg['train']:
+ optimizer = torch.optim.Adam([
+ {"params": model.parameters(),
+ "lr": float(cfg['train']['lr'])},
+ {"params": inputs,
+ "lr": float(cfg['train']['lr_pcl'])}])
+ else:
+            raise Exception('no learning rate specified in the config')
+ else:
+ if schedule is not None:
+ optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
+ else:
+ optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
+
+ return optimizer
+
+
+def is_url(url):
+ scheme = urllib.parse.urlparse(url).scheme
+ return scheme in ('http', 'https')
+
+def load_url(url):
+ '''Load a module dictionary from url.
+
+ Args:
+ url (str): url to saved model
+ '''
+ print(url)
+ print('=> Loading checkpoint from url...')
+ state_dict = model_zoo.load_url(url, progress=True)
+
+ return state_dict
+
+
+class GaussianSmoothing(nn.Module):
+ """
+    Apply Gaussian smoothing on a
+    1d, 2d or 3d tensor. Filtering is performed separately for each channel
+ in the input using a depthwise convolution.
+ Arguments:
+ channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
+ kernel_size (int, sequence): Size of the gaussian kernel.
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
+ dim (int, optional): The number of dimensions of the data.
+ Default value is 2 (spatial).
+ """
+ def __init__(self, channels, kernel_size, sigma, dim=3):
+ super(GaussianSmoothing, self).__init__()
+ if isinstance(kernel_size, numbers.Number):
+ kernel_size = [kernel_size] * dim
+ if isinstance(sigma, numbers.Number):
+ sigma = [sigma] * dim
+
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid(
+ [
+ torch.arange(size, dtype=torch.float32)
+ for size in kernel_size
+ ]
+ )
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
+ torch.exp(-((mgrid - mean) / std) ** 2 / 2)
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer('weight', kernel)
+ self.groups = channels
+
+ if dim == 1:
+ self.conv = F.conv1d
+ elif dim == 2:
+ self.conv = F.conv2d
+ elif dim == 3:
+ self.conv = F.conv3d
+ else:
+ raise RuntimeError(
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
+ )
+
+ def forward(self, input):
+ """
+ Apply gaussian filter to input.
+ Arguments:
+ input (torch.Tensor): Input to apply gaussian filter on.
+ Returns:
+ filtered (torch.Tensor): Filtered output.
+ """
+ return self.conv(input, weight=self.weight, groups=self.groups)
+
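+# Usage sketch (the channel/kernel/sigma values are illustrative; this mirrors
+# how gauss_smooth is used in the training loss): the depthwise convolution has
+# no padding, so pad by kernel_size // 2 first to keep the spatial size:
+#   smoother = GaussianSmoothing(channels=1, kernel_size=7, sigma=2, dim=3)
+#   out = smoother(torch.nn.ReplicationPad3d(3)(x))  # same spatial size as x
+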
+# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
+def get_learning_rate_schedules(schedule_specs):
+
+ schedules = []
+
+ for key in schedule_specs.keys():
+ schedules.append(StepLearningRateSchedule(
+ schedule_specs[key]['initial'],
+ schedule_specs[key]["interval"],
+ schedule_specs[key]["factor"],
+ schedule_specs[key]["final"]))
+ return schedules
+
+class LearningRateSchedule:
+ def get_learning_rate(self, epoch):
+        pass
+
+class StepLearningRateSchedule(LearningRateSchedule):
+ def __init__(self, initial, interval, factor, final=1e-6):
+ self.initial = float(initial)
+ self.interval = interval
+ self.factor = factor
+ self.final = float(final)
+
+ def get_learning_rate(self, epoch):
+ lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
+ if lr > self.final:
+ return lr
+ else:
+ return self.final
+
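+# Worked example: StepLearningRateSchedule(1e-3, interval=100, factor=0.5, final=1e-5)
+# returns 1e-3 * 0.5**(250 // 100) = 2.5e-4 at epoch 250; once the decayed value
+# drops below `final`, the schedule returns `final` instead.
+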
+def adjust_learning_rate(lr_schedules, optimizer, epoch):
+ for i, param_group in enumerate(optimizer.param_groups):
+ param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
diff --git a/src/visualize.py b/src/visualize.py
new file mode 100644
index 0000000..0eec2da
--- /dev/null
+++ b/src/visualize.py
@@ -0,0 +1,175 @@
+import os
+import torch
+import numpy as np
+import open3d as o3d
+import matplotlib.pyplot as plt
+from skimage import measure
+from src.utils import calc_inters_points, grid_interp
+from scipy import ndimage
+from tqdm import trange
+from torchvision.utils import save_image
+from pdb import set_trace as st
+
+
+def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None):
+    ''' Visualize the oriented point cloud and the extracted mesh in Open3D.
+
+    Args:
+        vis (o3d.visualization.Visualizer): Open3D visualizer, or None to skip rendering
+        points, normals (torch.Tensor): predicted oriented point cloud
+        verts, faces (numpy array): extracted mesh
+        cfg (dict): config dictionary
+        it (int): iteration index used to name the output images
+        color_v (numpy array): optional per-vertex colors
+    '''
+ mesh = o3d.geometry.TriangleMesh()
+ mesh.vertices = o3d.utility.Vector3dVector(verts)
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
+ mesh.paint_uniform_color(np.array([0.7,0.7,0.7]))
+ if color_v is not None:
+ mesh.vertex_colors = o3d.utility.Vector3dVector(color_v)
+
+ if vis is not None:
+ dir_o3d = cfg['train']['dir_o3d']
+ wire = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
+
+ p = points.squeeze(0).detach().cpu().numpy()
+ n = normals.squeeze(0).detach().cpu().numpy()
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(p)
+ pcd.normals = o3d.utility.Vector3dVector(n)
+ pcd.paint_uniform_color(np.array([0.7,0.7,1.0]))
+ # pcd = pcd.uniform_down_sample(5)
+
+ vis.clear_geometries()
+ vis.add_geometry(mesh)
+ vis.update_geometry(mesh)
+
+        #! Thingi wheel - an example of how to set the camera in the Open3D viewer
+ vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
+ vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
+ vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
+ vis.get_view_control().set_zoom(0.7)
+ vis.poll_events()
+
+ out_path = os.path.join(dir_o3d, '{}.jpg'.format(it))
+ vis.capture_screen_image(out_path)
+
+ vis.clear_geometries()
+ vis.add_geometry(pcd, reset_bounding_box=False)
+ vis.update_geometry(pcd)
+ vis.get_render_option().point_show_normal=True # visualize point normals
+
+ vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
+ vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
+ vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
+ vis.get_view_control().set_zoom(0.7)
+ vis.poll_events()
+
+ out_path = os.path.join(dir_o3d, '{}_pcd.jpg'.format(it))
+ vis.capture_screen_image(out_path)
+
+def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'):
+ if pose is not None:
+ device = psr_grid.device
+ # get world coordinate of grid points [-1, 1]
+ res = psr_grid.shape[-1]
+ x = torch.linspace(-1, 1, steps=res)
+ co_x, co_y, co_z = torch.meshgrid(x, x, x)
+ co_grid = torch.stack(
+ [co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)],
+ dim=1).to(device).unsqueeze(0)
+
+        # visualize the projected soft occupancy / PSR values
+        res = 128 # resolution of the projected visualization image
+ psr_grid = psr_grid.reshape(-1)
+ out_mask = psr_grid>0
+ in_mask = psr_grid<0
+ pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze()
+ vis_mask = (pix[..., 0]>=0) & (pix[..., 0]<=res-1) & \
+ (pix[..., 1]>=0) & (pix[..., 1]<=res-1)
+ pix_out = pix[vis_mask & out_mask]
+ pix_in = pix[vis_mask & in_mask]
+
+ img = torch.ones([res,res]).to(device)
+ psr_grid = torch.sigmoid(- psr_grid * 5)
+ img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask]
+ img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask]
+ # save_image(img, 'tmp.png', normalize=True)
+ return img
+ elif out_dir is not None:
+ dir_psr_vis = out_dir
+ os.makedirs(dir_psr_vis, exist_ok=True)
+ psr_grid = psr_grid.squeeze().detach().cpu().numpy()
+ axis = ['x', 'y', 'z']
+ s = psr_grid.shape[0]
+ for idx in trange(s):
+ my_dpi = 100
+ plt.figure(figsize=(1000/my_dpi, 300/my_dpi), dpi=my_dpi)
+ plt.subplot(1, 3, 1)
+ plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode='nearest'), cmap='nipy_spectral')
+ plt.clim(-1, 1)
+ plt.colorbar()
+ plt.title('x')
+ plt.grid("off")
+ plt.axis("off")
+
+ plt.subplot(1, 3, 2)
+ plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode='nearest'), cmap='nipy_spectral')
+ plt.clim(-1, 1)
+ plt.colorbar()
+ plt.title('y')
+ plt.grid("off")
+ plt.axis("off")
+
+ plt.subplot(1, 3, 3)
+ plt.imshow(ndimage.rotate(psr_grid[:,:,idx], 90, mode='nearest'), cmap='nipy_spectral')
+ plt.clim(-1, 1)
+ plt.colorbar()
+ plt.title('z')
+ plt.grid("off")
+ plt.axis("off")
+
+
+ plt.savefig(os.path.join(dir_psr_vis, '{}'.format(idx)), pad_inches = 0, dpi=100)
+ plt.close()
+ os.system("rm {}/{}".format(dir_psr_vis, out_video_name))
+ os.system("ffmpeg -framerate 25 -start_number 0 -i {}/%d.png -pix_fmt yuv420p -crf 17 {}/{}".format(dir_psr_vis, dir_psr_vis, out_video_name))
+
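+# Note: with `pose` given, visualize_psr_grid returns a single projected image;
+# with `out_dir` given, it saves per-slice plots and stitches them into a video
+# (assumes ffmpeg is available on PATH).
+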
+def visualize_mesh_phong(v, f, n, pose, img_size, name, device='cpu'):
+ #! Mesh rendering using Phong shading model
+ _, mask, f_p, w = calc_inters_points(v, f, pose, img_size)
+ n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
+ n_inters = w[..., 0, None] * n_a.squeeze() + \
+ w[..., 1, None] * n_b.squeeze() + \
+ w[..., 2, None] * n_c.squeeze()
+ n_inters = n_inters.detach().to(device)
+ light_source = -pose.R@pose.T.squeeze()
+ light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float()
+ diffuse_per = torch.Tensor([0.7,0.7,0.7]).float()
+ ambiant = torch.Tensor([0.3,0.3,0.3]).float()
+
+ diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device)
+
+ phong = torch.ones([img_size[0]*img_size[1], 3]).to(device)
+ phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0)
+ pp = phong.reshape(img_size[0], img_size[1], -1)
+ save_image(pp.permute(2, 0, 1), name)
+
+def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None):
+ p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
+ # normals for p_inters
+ n_inters = None
+ if n is not None:
+ n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
+ n_inters = w[..., 0, None] * n_a.squeeze() + \
+ w[..., 1, None] * n_b.squeeze() + \
+ w[..., 2, None] * n_c.squeeze()
+ if ray is not None:
+ ray = ray.squeeze()[mask]
+
+ fea = None
+ if fea_grid is not None:
+ fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze()
+
+ # use MLP to regress color
+ color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze()
+
+ return color_pred, mask
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..2bc08f8
--- /dev/null
+++ b/train.py
@@ -0,0 +1,185 @@
+import os
+
+abspath = os.path.abspath(__file__)
+dname = os.path.dirname(abspath)
+os.chdir(dname)
+
+import torch
+import torch.optim as optim
+import open3d as o3d
+
+import numpy as np; np.set_printoptions(precision=4)
+import shutil, argparse, time
+from torch.utils.tensorboard import SummaryWriter
+
+from src import config
+from src.data import collate_remove_none, collate_stack_together, worker_init_fn
+from src.training import Trainer
+from src.model import Encode2Points
+from src.utils import load_config, initialize_logger, \
+AverageMeter, load_model_manual
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Train a Shape As Points model.')
+ parser.add_argument('config', type=str, help='Path to config file.')
+ parser.add_argument('--no_cuda', action='store_true', default=False,
+ help='disables CUDA training')
+ parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+
+ args = parser.parse_args()
+ cfg = load_config(args.config, 'configs/default.yaml')
+ use_cuda = not args.no_cuda and torch.cuda.is_available()
+ device = torch.device("cuda" if use_cuda else "cpu")
+ input_type = cfg['data']['input_type']
+ batch_size = cfg['train']['batch_size']
+ model_selection_metric = cfg['train']['model_selection_metric']
+
+    # require PyTorch >= 1.0
+    assert(int(torch.__version__.split('.')[0]) >= 1)
+
+ # boiler-plate
+ if cfg['train']['timestamp']:
+ cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
+ logger = initialize_logger(cfg)
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml'))
+
+ logger.info("using GPU: " + torch.cuda.get_device_name(0))
+
+ # TensorboardX writer
+ tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
+ if not os.path.exists(tblogdir):
+ os.makedirs(tblogdir, exist_ok=True)
+ writer = SummaryWriter(log_dir=tblogdir)
+
+
+ inputs = None
+ train_dataset = config.get_dataset('train', cfg)
+ val_dataset = config.get_dataset('val', cfg)
+ vis_dataset = config.get_dataset('vis', cfg)
+
+
+ collate_fn = collate_remove_none
+
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'], shuffle=True,
+ collate_fn=collate_fn,
+ worker_init_fn=worker_init_fn)
+
+ val_loader = torch.utils.data.DataLoader(
+ val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
+ collate_fn=collate_remove_none,
+ worker_init_fn=worker_init_fn)
+
+ vis_loader = torch.utils.data.DataLoader(
+ vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
+ collate_fn=collate_fn,
+ worker_init_fn=worker_init_fn)
+
+ if torch.cuda.device_count() > 1:
+ model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
+ else:
+ model = Encode2Points(cfg).to(device)
+
+ n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ logger.info('Number of parameters: %d'% n_parameter)
+    # load checkpoint if one exists; otherwise start from scratch
+    try:
+        state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
+        load_model_manual(state_dict['state_dict'], model)
+
+        out = "Load model from iteration %d" % state_dict.get('it', 0)
+        logger.info(out)
+    except FileNotFoundError:
+        state_dict = dict()
+
+ metric_val_best = state_dict.get(
+ 'loss_val_best', np.inf)
+
+ logger.info('Current best validation metric (%s): %.8f'
+ % (model_selection_metric, metric_val_best))
+
+ LR = float(cfg['train']['lr'])
+ optimizer = optim.Adam(model.parameters(), lr=LR)
+
+ start_epoch = state_dict.get('epoch', -1)
+ it = state_dict.get('it', -1)
+
+ trainer = Trainer(cfg, optimizer, device=device)
+ runtime = {}
+ runtime['all'] = AverageMeter()
+
+ # training loop
+ for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
+
+ for batch in train_loader:
+ it += 1
+
+ start = time.time()
+ loss, loss_each = trainer.train_step(inputs, batch, model)
+
+ # measure elapsed time
+ end = time.time()
+ runtime['all'].update(end - start)
+
+ if it % cfg['train']['print_every'] == 0:
+ log_text = ('[Epoch %02d] it=%d, loss=%.4f') %(epoch, it, loss)
+ writer.add_scalar('train/loss', loss, it)
+ if loss_each is not None:
+ for k, l in loss_each.items():
+ if l.item() != 0.:
+ log_text += (' loss_%s=%.4f') % (k, l.item())
+ writer.add_scalar('train/%s' % k, l, it)
+
+ log_text += (' time=%.3f / %.2f') % (runtime['all'].val, runtime['all'].sum)
+ logger.info(log_text)
+
+ if (it>0)& (it % cfg['train']['visualize_every'] == 0):
+ for i, batch_vis in enumerate(vis_loader):
+ trainer.save(model, batch_vis, it, i)
+ if i >= 4:
+ break
+ logger.info('Saved mesh and pointcloud')
+
+ # run validation
+ if it > 0 and (it % cfg['train']['validate_every']) == 0:
+ eval_dict = trainer.evaluate(val_loader, model)
+ metric_val = eval_dict[model_selection_metric]
+ logger.info('Validation metric (%s): %.4f'
+ % (model_selection_metric, metric_val))
+
+ for k, v in eval_dict.items():
+ writer.add_scalar('val/%s' % k, v, it)
+
+ if -(metric_val - metric_val_best) >= 0:
+ metric_val_best = metric_val
+ logger.info('New best model (loss %.4f)' % metric_val_best)
+ state = {'epoch': epoch,
+ 'it': it,
+ 'loss_val_best': metric_val_best}
+ state['state_dict'] = model.state_dict()
+ torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt'))
+
+            # save checkpoint
+            if epoch > 0 and (it % cfg['train']['checkpoint_every'] == 0):
+                state = {'epoch': epoch,
+                         'it': it,
+                         'loss_val_best': metric_val_best}
+                state['state_dict'] = model.state_dict()
+
+                torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
+                logger.info("Save new model at iteration %d" % it)
+
+                # keep a numbered backup copy (nested so that `state` is always defined)
+                if (it % cfg['train']['backup_every'] == 0):
+                    torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt'))
+                    logger.info("Backup model at iteration %d" % it)
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file