diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d60f8d6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,150 @@ +/out +/data +.vscode +.cache +*.pyc +*.pyd +*.pt +*.so +*.o +*.prof +*.swp +*.lib +*.obj +*.exp +.nfs* +*.jpg +*.png +*.ply +*.off +*.npz +*.txt +# *.sh + + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/README.md b/README.md index b2e3392..e7bf941 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Shape As Points (SAP) -[**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (13 min)**](https://youtu.be/TgR0NvYty0A)
+ +### [**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (12 min)**](https://youtu.be/TgR0NvYty0A)
![](./media/teaser_wheel.gif) @@ -10,13 +11,164 @@ Shape As Points: A Differentiable Poisson Solver **NeurIPS 2021 (Oral)** -## Code is coming soon! - If you find our code or paper useful, please consider citing ```bibtex @inproceedings{Peng2021SAP, - author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Andreas, Geiger}, + author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas}, title = {Shape As Points: A Differentiable Poisson Solver}, booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, year = {2021}} ``` + + +## Installation +First you have to make sure that you have all dependencies in place. +The simplest way to do so is to use [anaconda](https://www.anaconda.com/). + +You can create an anaconda environment called `sap` using +``` +conda env create -f environment.yaml +conda activate sap +``` + +Now, you can install [PyTorch3D](https://pytorch3d.org/) 0.6.0 from the [official instructions](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux) as follows +```sh +pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu102_pyt190/download.html +``` +And install [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter): +```sh +conda install pytorch-scatter -c pyg +``` + + +## Demo - Quick Start + +First, run the script to get the demo data: + +```bash +bash scripts/download_demo_data.sh +``` + +### Optimization-based 3D Surface Reconstruction + +You can now quickly test our code on the data shown in the teaser. To this end, simply run: + +```python +python optim_hierarchy.py configs/optim_based/teaser.yaml +``` +This script should create a folder `out/demo_optim` where the output meshes and the optimized oriented point clouds under different grid resolutions are stored. + +To visualize the optimization process on the fly, you can set `o3d_show: True` in [`configs/optim_based/teaser.yaml`](https://github.com/autonomousvision/shape_as_points/tree/main/configs/optim_based/teaser.yaml). + +### Learning-based 3D Surface Reconstruction +You can also test SAP on another application, where we reconstruct from unoriented point clouds with either **large noise** or **outliers** using a learned network. + +![](./media/results_large_noise.gif) + +For the point clouds with large noise as shown above, you can run: +```python +python generate.py configs/learning_based/demo_large_noise.yaml +``` +The results can be found at `out/demo_shapenet_large_noise/generation/vis`. + +![](./media/results_outliers.gif) +As for the point clouds with outliers, you can run: +```python +python generate.py configs/learning_based/demo_outlier.yaml +``` +You can find the reconstruction at `out/demo_shapenet_outlier/generation/vis`. + + +## Dataset + +We use different datasets for the optimization-based and learning-based settings. + +### Dataset for Optimization-based Reconstruction +Here we consider the following datasets: +- [Thingi10K](https://arxiv.org/abs/1605.04797) (synthetic) +- [Surface Reconstruction Benchmark (SRB)](https://github.com/fwilliams/deep-geometric-prior) (real scans) +- [MPI Dynamic FAUST](https://dfaust.is.tue.mpg.de/) (real scans) + +Please cite the corresponding papers if you use the data.
+ +You can download the processed dataset (~200 MB) by running: +```bash +bash scripts/download_optim_data.sh +``` + +### Dataset for Learning-based Reconstruction +We train and evaluate on [ShapeNet](https://shapenet.org/). +You can download the processed dataset (~220 GB) by running: +```bash +bash scripts/download_shapenet.sh +``` +Afterwards, you should have the dataset in the `data/shapenet_psr` folder. + +Alternatively, you can also preprocess the dataset yourself. To this end, you can: +* first download the preprocessed dataset (73.4 GB) by running [the script](https://github.com/autonomousvision/occupancy_networks#preprocessed-data) from Occupancy Networks. +* check [`scripts/process_shapenet.py`](https://github.com/autonomousvision/shape_as_points/tree/main/scripts/process_shapenet.py), modify the base path and run the code. + + +## Usage for Optimization-based 3D Reconstruction + +For our optimization-based setting, you can consider running with a coarse-to-fine strategy: +```python +python optim_hierarchy.py configs/optim_based/CONFIG.yaml +``` +We start from a grid resolution of 32^3, and increase to 64^3, 128^3 and finally 256^3. + +Alternatively, you can also run on a single resolution with: + +```python +python optim.py configs/optim_based/CONFIG.yaml +``` +You might need to modify the `CONFIG.yaml` accordingly. + +## Usage for Learning-based 3D Reconstruction + +### Mesh Generation +To generate meshes using a trained model, use +```python +python generate.py configs/learning_based/CONFIG.yaml +``` +where you replace `CONFIG.yaml` with the correct config file. + +#### Use a pre-trained model +The easiest way is to use a pre-trained model. You can do this by using one of the config files with the postfix `_pretrained`. + +For example, for 3D reconstruction from point clouds with outliers using our model with 7x offsets, you can simply run: +```python +python generate.py configs/learning_based/outlier/ours_7x_pretrained.yaml +``` + +The script will automatically download the pretrained model and run the generation. You can find the outputs in the `out/.../generation_pretrained` folders. + +**Note**: these `_pretrained` config files are only for generation, not for training new models: if they are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model. + +We provide the following pretrained models: +``` +noise_small/ours.pt +noise_large/ours.pt +outlier/ours_1x.pt +outlier/ours_3x.pt +outlier/ours_5x.pt +outlier/ours_7x.pt +outlier/ours_3plane.pt +``` + + +### Evaluation +To evaluate a trained model, we provide the script [`eval_meshes.py`](https://github.com/autonomousvision/shape_as_points/blob/main/eval_meshes.py). You can run it using: +```python +python eval_meshes.py configs/learning_based/CONFIG.yaml +``` +The script takes the meshes generated in the previous step and evaluates them using a standardized protocol. The output will be written to `.pkl` and `.csv` files in the corresponding generation folder, which can be processed using [pandas](https://pandas.pydata.org/). + +### Training + +Finally, to train a new network from scratch, simply run: +```python +python train.py configs/learning_based/CONFIG.yaml +``` +For available training options, please take a look at `configs/default.yaml`.
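+ +For instance, a minimal custom config might look like the sketch below; any field that is not set should fall back to the inherited config and then to `configs/default.yaml` (the output directory and batch size here are just placeholder values): +```yaml +inherit_from: configs/learning_based/noise_small/ours.yaml +train: + out_dir: out/my_experiment # placeholder output folder + batch_size: 16 # override only the fields that differ from the inherited config +```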
+ diff --git a/configs/default.yaml b/configs/default.yaml new file mode 100644 index 0000000..b1205ce --- /dev/null +++ b/configs/default.yaml @@ -0,0 +1,104 @@ +data: + dataset: Shapes3D + path: data/ShapeNet + class: null + data_type: img + input_type: pointcloud + dim: 3 + num_points: 1000 + num_gt_points: 1000 + num_offset: 1 + img_size: null + n_views_input: 20 + n_views_per_iter: 2 + pointcloud_noise: 0 + pointcloud_file: pointcloud.npz + pointcloud_outlier_ratio: 0 + fixed_scale: 0 + train_split: train + val_split: val + test_split: test + points_file: null + points_iou_file: points.npz + points_unpackbits: true + padding: 0.1 + multi_files: null + gt_mesh: null + zero_level: 0 + only_single: False + sample_only_floor: False +model: + apply_sigmoid: True + grid_res: 128 # poisson grid resolution + psr_sigma: 0 + psr_tanh: False + normal_normalize: False + raster: {} + renderer: {} + encoder: null + predict_normal: True + predict_offset: True + s_offset: 0.001 + local_coord: True + encoder_kwargs: {} + unet3d: False + unet3d_kwargs: {} + multi_gpu: false + rotate_matrix: false + c_dim: 512 + sphere_radius: 0.2 +train: + lr: 1e-3 + lr_pcl: 2e-2 + input_mesh: '' + out_dir: out/default + subsample_vertex: False + batch_size: 1 + n_grow_points: 0 + resample_every: 0 + l_weight: {} + w_reg_point: 0 + w_psr: 0 + w_raw: 0 # train with raw point cloud + gauss_weight: 0 + n_sup_point: 0 + w_normals: 0 + total_epochs: 1000 + print_every: 20 + visualize_every: 1000 + save_every: 1000 + vis_vert_color: True + o3d_show: False + o3d_vis_pcl: True + o3d_window_size: 540 + vis_rendering: False + vis_psr: False + save_video: False + exp_mesh: True + exp_pcl: True + checkpoint_every: 1000 + validate_every: 2000000 + backup_every: 100000 + timestamp: False # add timestamp to out_dir name + model_selection_metric: loss + model_selection_mode: minimize + n_workers: 0 + n_workers_val: 0 +test: + threshold: 0.5 + eval_mesh: true + eval_pointcloud: false + model_file: model_best.pt +generation: + batch_size: 100000 + exp_gt: False + exp_oracle: false + exp_input: False + vis_n_outputs: 10 + generate_mesh: true + generate_pointcloud: true + generation_dir: generation + copy_input: true + use_sampling: false + psr_resolution: 0 + psr_sigma: 0 diff --git a/configs/learning_based/demo_large_noise.yaml b/configs/learning_based/demo_large_noise.yaml new file mode 100644 index 0000000..cfb7da2 --- /dev/null +++ b/configs/learning_based/demo_large_noise.yaml @@ -0,0 +1,9 @@ + +inherit_from: configs/learning_based/noise_large/ours.yaml +data: + class: [''] + path: data/demo/shapenet_chair +train: + out_dir: out/demo_shapenet_large_noise +test: + model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt \ No newline at end of file diff --git a/configs/learning_based/demo_outlier.yaml b/configs/learning_based/demo_outlier.yaml new file mode 100644 index 0000000..246ca49 --- /dev/null +++ b/configs/learning_based/demo_outlier.yaml @@ -0,0 +1,9 @@ + +inherit_from: configs/learning_based/outlier/ours_7x.yaml +data: + class: [''] + path: data/demo/shapenet_lamp +train: + out_dir: out/demo_shapenet_outlier +test: + model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt \ No newline at end of file diff --git a/configs/learning_based/noise_large/ours.yaml b/configs/learning_based/noise_large/ours.yaml new file mode 100644 index 0000000..2c8eac1 --- /dev/null +++ b/configs/learning_based/noise_large/ours.yaml @@ -0,0 +1,54 @@ 
+data: + class: null + data_type: psr_full + input_type: pointcloud + path: data/shapenet_psr + num_gt_points: 10000 + num_offset: 7 + pointcloud_n: 3000 + pointcloud_noise: 0.025 +model: + grid_res: 128 # poisson grid resolution + psr_sigma: 2 + psr_tanh: True + normal_normalize: False + predict_normal: True + predict_offset: True + c_dim: 32 + s_offset: 0.001 + encoder: local_pool_pointnet + encoder_kwargs: + hidden_dim: 32 + plane_type: 'grid' + grid_resolution: 32 + unet3d: True + unet3d_kwargs: + num_levels: 3 + f_maps: 32 + in_channels: 32 + out_channels: 32 + decoder: simple_local + decoder_kwargs: + sample_mode: bilinear # bilinear / nearest + hidden_size: 32 +train: + batch_size: 32 + lr: 5e-4 + out_dir: out/shapenet/noise_025_ours + w_psr: 1 + model_selection_metric: psr_l2 + print_every: 100 + checkpoint_every: 200 + validate_every: 5000 + backup_every: 10000 + total_epochs: 400000 + visualize_every: 5000 + exp_pcl: True + exp_mesh: True + n_workers: 8 + n_workers_val: 0 +generation: + exp_gt: False + exp_input: True + psr_resolution: 128 + psr_sigma: 2 diff --git a/configs/learning_based/noise_large/ours_pretrained.yaml b/configs/learning_based/noise_large/ours_pretrained.yaml new file mode 100644 index 0000000..bec016c --- /dev/null +++ b/configs/learning_based/noise_large/ours_pretrained.yaml @@ -0,0 +1,5 @@ +inherit_from: configs/learning_based/noise_large/ours.yaml +generation: + generation_dir: generation_pretrained +test: + model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt \ No newline at end of file diff --git a/configs/learning_based/noise_small/ours.yaml b/configs/learning_based/noise_small/ours.yaml new file mode 100644 index 0000000..3b1eb3a --- /dev/null +++ b/configs/learning_based/noise_small/ours.yaml @@ -0,0 +1,54 @@ +data: + class: null + data_type: psr_full + input_type: pointcloud + path: data/shapenet_psr + num_gt_points: 10000 + num_offset: 7 + pointcloud_n: 3000 + pointcloud_noise: 0.005 +model: + grid_res: 128 # poisson grid resolution + psr_sigma: 2 + psr_tanh: True + normal_normalize: False + predict_normal: True + predict_offset: True + c_dim: 32 + s_offset: 0.001 + encoder: local_pool_pointnet + encoder_kwargs: + hidden_dim: 32 + plane_type: 'grid' + grid_resolution: 32 + unet3d: True + unet3d_kwargs: + num_levels: 3 + f_maps: 32 + in_channels: 32 + out_channels: 32 + decoder: simple_local + decoder_kwargs: + sample_mode: bilinear # bilinear / nearest + hidden_size: 32 +train: + batch_size: 32 + lr: 5e-4 + out_dir: out/shapenet/noise_005_ours + w_psr: 1 + model_selection_metric: psr_l2 + print_every: 100 + checkpoint_every: 200 + validate_every: 5000 + backup_every: 10000 + total_epochs: 400000 + visualize_every: 5000 + exp_pcl: True + exp_mesh: True + n_workers: 8 + n_workers_val: 0 +generation: + exp_gt: False + exp_input: True + psr_resolution: 128 + psr_sigma: 2 diff --git a/configs/learning_based/noise_small/ours_pretrained.yaml b/configs/learning_based/noise_small/ours_pretrained.yaml new file mode 100644 index 0000000..fc3317e --- /dev/null +++ b/configs/learning_based/noise_small/ours_pretrained.yaml @@ -0,0 +1,5 @@ +inherit_from: configs/learning_based/noise_small/ours.yaml +generation: + generation_dir: generation_pretrained +test: + model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_005.pt \ No newline at end of file diff --git a/configs/learning_based/outlier/ours_1x.yaml b/configs/learning_based/outlier/ours_1x.yaml new file mode 100644 
index 0000000..c973789 --- /dev/null +++ b/configs/learning_based/outlier/ours_1x.yaml @@ -0,0 +1,55 @@ +data: + class: null + data_type: psr_full + input_type: pointcloud + path: data/shapenet_psr + num_gt_points: 10000 + num_offset: 1 + pointcloud_n: 3000 + pointcloud_noise: 0.005 + pointcloud_outlier_ratio: 0.5 +model: + grid_res: 128 # poisson grid resolution + psr_sigma: 2 + psr_tanh: True + normal_normalize: False + predict_normal: True + predict_offset: True + c_dim: 32 + s_offset: 0.001 + encoder: local_pool_pointnet + encoder_kwargs: + hidden_dim: 32 + plane_type: 'grid' + grid_resolution: 32 + unet3d: True + unet3d_kwargs: + num_levels: 3 + f_maps: 32 + in_channels: 32 + out_channels: 32 + decoder: simple_local + decoder_kwargs: + sample_mode: bilinear # bilinear / nearest + hidden_size: 32 +train: + batch_size: 32 + lr: 5e-4 + out_dir: out/shapenet/outlier_ours_1x + w_psr: 1 + model_selection_metric: psr_l2 + print_every: 100 + checkpoint_every: 200 + validate_every: 5000 + backup_every: 10000 + total_epochs: 400000 + visualize_every: 5000 + exp_pcl: True + exp_mesh: True + n_workers: 8 + n_workers_val: 0 +generation: + exp_gt: False + exp_input: True + psr_resolution: 128 + psr_sigma: 2 diff --git a/configs/learning_based/outlier/ours_1x_pretrained.yaml b/configs/learning_based/outlier/ours_1x_pretrained.yaml new file mode 100644 index 0000000..5aaf6b5 --- /dev/null +++ b/configs/learning_based/outlier/ours_1x_pretrained.yaml @@ -0,0 +1,5 @@ +inherit_from: configs/learning_based/outlier/ours_1x.yaml +generation: + generation_dir: generation_pretrained +test: + model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_1x.pt \ No newline at end of file diff --git a/configs/learning_based/outlier/ours_3plane.yaml b/configs/learning_based/outlier/ours_3plane.yaml new file mode 100644 index 0000000..0c2aa3b --- /dev/null +++ b/configs/learning_based/outlier/ours_3plane.yaml @@ -0,0 +1,54 @@ +data: + class: null + data_type: psr_full + input_type: pointcloud + path: data/shapenet_psr + num_gt_points: 10000 + num_offset: 5 + pointcloud_n: 3000 + pointcloud_noise: 0.005 + pointcloud_outlier_ratio: 0.5 +model: + grid_res: 128 # poisson grid resolution + psr_sigma: 2 + psr_tanh: True + normal_normalize: False + predict_normal: True + predict_offset: True + c_dim: 32 + s_offset: 0.001 + encoder: local_pool_pointnet + encoder_kwargs: + hidden_dim: 32 + plane_type: ['xz', 'xy', 'yz'] + plane_resolution: 64 + unet: True + unet_kwargs: + depth: 4 + merge_mode: concat + start_filts: 32 + decoder: simple_local + decoder_kwargs: + sample_mode: bilinear # bilinear / nearest + hidden_size: 32 +train: + batch_size: 32 + lr: 5e-4 + out_dir: out/shapenet/outlier_ours_3plane + w_psr: 1 + model_selection_metric: psr_l2 + print_every: 100 + checkpoint_every: 200 + validate_every: 5000 + backup_every: 10000 + total_epochs: 400000 + visualize_every: 5000 + exp_pcl: True + exp_mesh: True + n_workers: 8 + n_workers_val: 0 +generation: + exp_gt: False + exp_input: True + psr_resolution: 128 + psr_sigma: 2 diff --git a/configs/learning_based/outlier/ours_3x.yaml b/configs/learning_based/outlier/ours_3x.yaml new file mode 100644 index 0000000..e976867 --- /dev/null +++ b/configs/learning_based/outlier/ours_3x.yaml @@ -0,0 +1,55 @@ +data: + class: null + data_type: psr_full + input_type: pointcloud + path: data/shapenet_psr + num_gt_points: 10000 + num_offset: 3 + pointcloud_n: 3000 + pointcloud_noise: 0.005 + pointcloud_outlier_ratio: 0.5 +model: + grid_res: 128 # 
poisson grid resolution + psr_sigma: 2 + psr_tanh: True + normal_normalize: False + predict_normal: True + predict_offset: True + c_dim: 32 + s_offset: 0.001 + encoder: local_pool_pointnet + encoder_kwargs: + hidden_dim: 32 + plane_type: 'grid' + grid_resolution: 32 + unet3d: True + unet3d_kwargs: + num_levels: 3 + f_maps: 32 + in_channels: 32 + out_channels: 32 + decoder: simple_local + decoder_kwargs: + sample_mode: bilinear # bilinear / nearest + hidden_size: 32 +train: + batch_size: 32 + lr: 5e-4 + out_dir: out/shapenet/outlier_ours_3x + w_psr: 1 + model_selection_metric: psr_l2 + print_every: 100 + checkpoint_every: 200 + validate_every: 5000 + backup_every: 10000 + total_epochs: 400000 + visualize_every: 5000 + exp_pcl: True + exp_mesh: True + n_workers: 8 + n_workers_val: 0 +generation: + exp_gt: False + exp_input: True + psr_resolution: 128 + psr_sigma: 2 diff --git a/configs/learning_based/outlier/ours_3x_pretrained.yaml b/configs/learning_based/outlier/ours_3x_pretrained.yaml new file mode 100644 index 0000000..57e2590 --- /dev/null +++ b/configs/learning_based/outlier/ours_3x_pretrained.yaml @@ -0,0 +1,5 @@ +inherit_from: configs/learning_based/outlier/ours_3x.yaml +generation: + generation_dir: generation_pretrained +test: + model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_3x.pt \ No newline at end of file diff --git a/configs/learning_based/outlier/ours_5x.yaml b/configs/learning_based/outlier/ours_5x.yaml new file mode 100644 index 0000000..0f067c3 --- /dev/null +++ b/configs/learning_based/outlier/ours_5x.yaml @@ -0,0 +1,55 @@ +data: + class: null + data_type: psr_full + input_type: pointcloud + path: data/shapenet_psr + num_gt_points: 10000 + num_offset: 5 + pointcloud_n: 3000 + pointcloud_noise: 0.005 + pointcloud_outlier_ratio: 0.5 +model: + grid_res: 128 # poisson grid resolution + psr_sigma: 2 + psr_tanh: True + normal_normalize: False + predict_normal: True + predict_offset: True + c_dim: 32 + s_offset: 0.001 + encoder: local_pool_pointnet + encoder_kwargs: + hidden_dim: 32 + plane_type: 'grid' + grid_resolution: 32 + unet3d: True + unet3d_kwargs: + num_levels: 3 + f_maps: 32 + in_channels: 32 + out_channels: 32 + decoder: simple_local + decoder_kwargs: + sample_mode: bilinear # bilinear / nearest + hidden_size: 32 +train: + batch_size: 32 + lr: 5e-4 + out_dir: out/shapenet/outlier_ours_5x + w_psr: 1 + model_selection_metric: psr_l2 + print_every: 100 + checkpoint_every: 200 + validate_every: 5000 + backup_every: 10000 + total_epochs: 400000 + visualize_every: 5000 + exp_pcl: True + exp_mesh: True + n_workers: 8 + n_workers_val: 0 +generation: + exp_gt: False + exp_input: True + psr_resolution: 128 + psr_sigma: 2 diff --git a/configs/learning_based/outlier/ours_5x_pretrained.yaml b/configs/learning_based/outlier/ours_5x_pretrained.yaml new file mode 100644 index 0000000..2ae21a2 --- /dev/null +++ b/configs/learning_based/outlier/ours_5x_pretrained.yaml @@ -0,0 +1,5 @@ +inherit_from: configs/learning_based/outlier/ours_5x.yaml +generation: + generation_dir: generation_pretrained +test: + model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_5x.pt \ No newline at end of file diff --git a/configs/learning_based/outlier/ours_7x.yaml b/configs/learning_based/outlier/ours_7x.yaml new file mode 100644 index 0000000..4152279 --- /dev/null +++ b/configs/learning_based/outlier/ours_7x.yaml @@ -0,0 +1,55 @@ +data: + class: null + data_type: psr_full + input_type: pointcloud + 
path: data/shapenet_psr + num_gt_points: 10000 + num_offset: 7 + pointcloud_n: 3000 + pointcloud_noise: 0.005 + pointcloud_outlier_ratio: 0.5 +model: + grid_res: 128 # poisson grid resolution + psr_sigma: 2 + psr_tanh: True + normal_normalize: False + predict_normal: True + predict_offset: True + c_dim: 32 + s_offset: 0.001 + encoder: local_pool_pointnet + encoder_kwargs: + hidden_dim: 32 + plane_type: 'grid' + grid_resolution: 32 + unet3d: True + unet3d_kwargs: + num_levels: 3 + f_maps: 32 + in_channels: 32 + out_channels: 32 + decoder: simple_local + decoder_kwargs: + sample_mode: bilinear # bilinear / nearest + hidden_size: 32 +train: + batch_size: 32 + lr: 5e-4 + out_dir: out/shapenet/outlier_ours_7x + w_psr: 1 + model_selection_metric: psr_l2 + print_every: 100 + checkpoint_every: 200 + validate_every: 5000 + backup_every: 10000 + total_epochs: 400000 + visualize_every: 5000 + exp_pcl: True + exp_mesh: True + n_workers: 8 + n_workers_val: 0 +generation: + exp_gt: False + exp_input: True + psr_resolution: 128 + psr_sigma: 2 diff --git a/configs/learning_based/outlier/ours_7x_pretrained.yaml b/configs/learning_based/outlier/ours_7x_pretrained.yaml new file mode 100644 index 0000000..a46e621 --- /dev/null +++ b/configs/learning_based/outlier/ours_7x_pretrained.yaml @@ -0,0 +1,5 @@ +inherit_from: configs/learning_based/outlier/ours_7x.yaml +generation: + generation_dir: generation_pretrained +test: + model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt \ No newline at end of file diff --git a/configs/optim_based/dfaust.yaml b/configs/optim_based/dfaust.yaml new file mode 100644 index 0000000..226d24d --- /dev/null +++ b/configs/optim_based/dfaust.yaml @@ -0,0 +1,37 @@ +data: + class: 'only_pcl' + data_type: point + data_path: 'data/dfaust/*.ply' + object_id: 0 + num_points: 20000 +model: + sphere_radius: 0.2 + grid_res: 256 # poisson grid resolution + psr_sigma: 2 +train: + schedule: + pcl: + initial: 1e-2 + interval: 700 + factor: 0.5 + final: 1e-3 + out_dir: out/dfaust + w_chamfer: 1 + n_sup_point: 20000 + batch_size: 1 + n_grow_points: 2000 + resample_every: 200 + subsample_vertex: False + total_epochs: 1600 + print_every: 10 + visualize_every: 2 + checkpoint_every: 500 + save_every: 100 + exp_pcl: True + exp_mesh: True + o3d_show: False + o3d_vis_pcl: True + o3d_window_size: 540 + vis_rendering: False + vis_vert_color: False + n_workers: 0 diff --git a/configs/optim_based/dgp.yaml b/configs/optim_based/dgp.yaml new file mode 100644 index 0000000..79f0726 --- /dev/null +++ b/configs/optim_based/dgp.yaml @@ -0,0 +1,38 @@ +data: + class: 'only_pcl' + data_type: point + data_path: 'data/deep_geometric_prior_data/*.ply' + object_id: 0 + num_points: 20000 +model: + sphere_radius: 0.2 + grid_res: 256 # poisson grid resolution + psr_sigma: 2 +train: + schedule: + pcl: + initial: 1e-2 + interval: 700 + factor: 0.5 + final: 1e-3 + out_dir: out/dgp + w_reg_point: 0 + w_chamfer: 1 + n_sup_point: 20000 + batch_size: 1 + n_grow_points: 2000 + resample_every: 200 + subsample_vertex: False + total_epochs: 1600 + print_every: 10 + visualize_every: 2 + checkpoint_every: 500 + save_every: 100 + exp_pcl: True + exp_mesh: True + o3d_show: False + o3d_vis_pcl: True + o3d_window_size: 540 + vis_rendering: False + vis_vert_color: False + n_workers: 0 diff --git a/configs/optim_based/teaser.yaml b/configs/optim_based/teaser.yaml new file mode 100644 index 0000000..40fd2d4 --- /dev/null +++ b/configs/optim_based/teaser.yaml @@ -0,0 +1,37 @@ +data: + 
class: 'only_pcl' + data_type: point + data_path: 'data/demo/wheel.obj' + object_id: 0 + num_points: 20000 +model: + sphere_radius: 0.2 + grid_res: 128 # poisson grid resolution + psr_sigma: 2 +train: + schedule: + pcl: + initial: 1e-2 + interval: 700 + factor: 0.5 + final: 1e-3 + out_dir: out/demo_optim + w_chamfer: 1 + n_sup_point: 20000 + batch_size: 1 + n_grow_points: 2000 + resample_every: 200 + subsample_vertex: False + total_epochs: 1600 + print_every: 10 + visualize_every: 2 + checkpoint_every: 500 + save_every: 100 + exp_pcl: True + exp_mesh: True + o3d_show: False + o3d_vis_pcl: True + o3d_window_size: 540 + vis_rendering: False + vis_vert_color: False + n_workers: 0 \ No newline at end of file diff --git a/configs/optim_based/thingi.yaml b/configs/optim_based/thingi.yaml new file mode 100644 index 0000000..fa5f1a6 --- /dev/null +++ b/configs/optim_based/thingi.yaml @@ -0,0 +1,39 @@ +data: + class: 'only_pcl' + data_type: point + data_path: 'data/thingi/*.ply' + object_id: 0 + num_points: 20000 +model: + sphere_radius: 0.2 + grid_res: 128 # poisson grid resolution + psr_sigma: 2 +train: + # lr_pcl: 2e-2 + schedule: + pcl: + initial: 1e-2 + interval: 700 + factor: 0.5 + final: 1e-3 + out_dir: out/thingi + w_reg_point: 0 + w_chamfer: 1 + n_sup_point: 20000 + batch_size: 1 + n_grow_points: 2000 + resample_every: 200 + subsample_vertex: False + total_epochs: 1600 + print_every: 10 + visualize_every: 2 + checkpoint_every: 500 + save_every: 100 + exp_pcl: True + exp_mesh: True + o3d_show: False + o3d_vis_pcl: True + o3d_window_size: 540 + vis_rendering: False + vis_vert_color: False + n_workers: 0 diff --git a/configs/optim_based/thingi_noisy.yaml b/configs/optim_based/thingi_noisy.yaml new file mode 100644 index 0000000..3abc166 --- /dev/null +++ b/configs/optim_based/thingi_noisy.yaml @@ -0,0 +1,39 @@ +data: + class: 'only_pcl' + data_type: point + data_path: 'data/thingi_noisy/*.ply' + object_id: 0 + num_points: 20000 +model: + sphere_radius: 0.2 + grid_res: 128 # poisson grid resolution + psr_sigma: 2 +train: + # lr_pcl: 2e-2 + schedule: + pcl: + initial: 1e-2 + interval: 700 + factor: 0.5 + final: 1e-3 + out_dir: out/thingi_noisy + w_reg_point: 0 + w_chamfer: 1 + n_sup_point: 20000 + batch_size: 1 + n_grow_points: 2000 + resample_every: 200 + subsample_vertex: False + total_epochs: 1600 + print_every: 10 + visualize_every: 2 + checkpoint_every: 500 + save_every: 100 + exp_pcl: True + exp_mesh: True + o3d_show: False + o3d_vis_pcl: True + o3d_window_size: 540 + vis_rendering: False + vis_vert_color: False + n_workers: 0 diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..ae023d7 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,29 @@ +name: sap +channels: + - anaconda + - conda-forge + - pytorch + - defaults +dependencies: + - python=3.8 + - pytorch=1.9.0 + - torchvision=0.10.0 + - cudatoolkit=10.2 + - numpy=1.18.1 + - matplotlib=3.4.3 + - pyyaml=5.3.1 + - scipy=1.4.1 + - tqdm=4.54.0 + - trimesh=3.8.14 + - igl=2.2.1 + - tensorboard=2.6.0 + - pip + - pip: + - plyfile==0.7 + - open3d>=0.11.1 + - scikit-image>=0.18.0 + - python-mnist==0.7 + - opencv-python>=4.4 + - av==8.0.3 + - pykdtree==1.3.4 + - ipdb==0.13.7 diff --git a/eval_meshes.py b/eval_meshes.py new file mode 100644 index 0000000..96a82e5 --- /dev/null +++ b/eval_meshes.py @@ -0,0 +1,155 @@ +import torch +import trimesh +from torch.utils.data import Dataset, DataLoader +import numpy as np; np.set_printoptions(precision=4) +import shutil, argparse, time, os +import pandas as pd +from src.data 
import collate_remove_none, collate_stack_together, worker_init_fn +from src.training import Trainer +from src.model import Encode2Points +from src.data import PointCloudField, IndexField, Shapes3dDataset +from src.utils import load_config, load_pointcloud +from src.eval import MeshEvaluator +from tqdm import tqdm +from pdb import set_trace as st + + +def main(): + parser = argparse.ArgumentParser(description='MNIST toy experiment') + parser.add_argument('config', type=str, help='Path to config file.') + parser.add_argument('--no_cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') + parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.') + + args = parser.parse_args() + cfg = load_config(args.config, 'configs/default.yaml') + use_cuda = not args.no_cuda and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + data_type = cfg['data']['data_type'] + # Shorthands + out_dir = cfg['train']['out_dir'] + generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir']) + + if cfg['generation'].get('iter', 0)!=0: + generation_dir += '_%04d'%cfg['generation']['iter'] + elif args.iter is not None: + generation_dir += '_%04d'%args.iter + + print('Evaluate meshes under %s'%generation_dir) + + out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl') + out_file_class = os.path.join(generation_dir, 'eval_meshes.csv') + + # PYTORCH VERSION > 1.0.0 + assert(float(torch.__version__.split('.')[-3]) > 0) + + pointcloud_field = PointCloudField(cfg['data']['pointcloud_file']) + fields = { + 'pointcloud': pointcloud_field, + 'idx': IndexField(), + } + + print('Test split: ', cfg['data']['test_split']) + + dataset_folder = cfg['data']['path'] + dataset = Shapes3dDataset( + dataset_folder, fields, + cfg['data']['test_split'], + categories=cfg['data']['class'], cfg=cfg) + + # Loader + test_loader = torch.utils.data.DataLoader( + dataset, batch_size=1, num_workers=0, shuffle=False) + + # Evaluator + evaluator = MeshEvaluator(n_points=100000) + + eval_dicts = [] + print('Evaluating meshes...') + for it, data in enumerate(tqdm(test_loader)): + + if data is None: + print('Invalid data.') + continue + + mesh_dir = os.path.join(generation_dir, 'meshes') + pointcloud_dir = os.path.join(generation_dir, 'pointcloud') + + + # Get index etc. 
+ idx = data['idx'].item() + try: + model_dict = dataset.get_model_dict(idx) + except AttributeError: + model_dict = {'model': str(idx), 'category': 'n/a'} + + modelname = model_dict['model'] + category_id = model_dict['category'] + + try: + category_name = dataset.metadata[category_id].get('name', 'n/a') + except AttributeError: + category_name = 'n/a' + + if category_id != 'n/a': + mesh_dir = os.path.join(mesh_dir, category_id) + pointcloud_dir = os.path.join(pointcloud_dir, category_id) + + # Evaluate + pointcloud_tgt = data['pointcloud'].squeeze(0).numpy() + normals_tgt = data['pointcloud.normals'].squeeze(0).numpy() + + + eval_dict = { + 'idx': idx, + 'class id': category_id, + 'class name': category_name, + 'modelname':modelname, + } + eval_dicts.append(eval_dict) + + # Evaluate mesh + if cfg['test']['eval_mesh']: + mesh_file = os.path.join(mesh_dir, '%s.off' % modelname) + + if os.path.exists(mesh_file): + mesh = trimesh.load(mesh_file, process=False) + eval_dict_mesh = evaluator.eval_mesh( + mesh, pointcloud_tgt, normals_tgt) + for k, v in eval_dict_mesh.items(): + eval_dict[k + ' (mesh)'] = v + else: + print('Warning: mesh does not exist: %s' % mesh_file) + + # Evaluate point cloud + if cfg['test']['eval_pointcloud']: + pointcloud_file = os.path.join( + pointcloud_dir, '%s.ply' % modelname) + + if os.path.exists(pointcloud_file): + pointcloud = load_pointcloud(pointcloud_file).astype(np.float32) + eval_dict_pcl = evaluator.eval_pointcloud( + pointcloud, pointcloud_tgt) + for k, v in eval_dict_pcl.items(): + eval_dict[k + ' (pcl)'] = v + else: + print('Warning: pointcloud does not exist: %s' + % pointcloud_file) + + + # Create pandas dataframe and save + eval_df = pd.DataFrame(eval_dicts) + eval_df.set_index(['idx'], inplace=True) + eval_df.to_pickle(out_file) + + # Create CSV file with main statistics + eval_df_class = eval_df.groupby(by=['class name']).mean() + eval_df_class.loc['mean'] = eval_df_class.mean() + eval_df_class.to_csv(out_file_class) + + # Print results + print(eval_df_class) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/generate.py b/generate.py new file mode 100644 index 0000000..08dd28b --- /dev/null +++ b/generate.py @@ -0,0 +1,217 @@ +import torch +from torch.utils.data import Dataset, DataLoader +import numpy as np; np.set_printoptions(precision=4) +import shutil, argparse, time, os +import pandas as pd +from collections import defaultdict +from src import config +from src.utils import mc_from_psr, export_mesh, export_pointcloud +from src.dpsr import DPSR +from src.training import Trainer +from src.model import Encode2Points +from src.utils import load_config, load_model_manual, scale2onet, is_url, load_url +from tqdm import tqdm +from pdb import set_trace as st + + +def main(): + parser = argparse.ArgumentParser(description='MNIST toy experiment') + parser.add_argument('config', type=str, help='Path to config file.') + parser.add_argument('--no_cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') + parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.') + + args = parser.parse_args() + cfg = load_config(args.config, 'configs/default.yaml') + use_cuda = not args.no_cuda and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + data_type = cfg['data']['data_type'] + input_type = cfg['data']['input_type'] + vis_n_outputs = 
cfg['generation']['vis_n_outputs'] + if vis_n_outputs is None: + vis_n_outputs = -1 + # Shorthands + out_dir = cfg['train']['out_dir'] + if not os.path.exists(out_dir): + os.makedirs(out_dir) + generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir']) + out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl') + out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl') + + # PYTORCH VERSION > 1.0.0 + assert(float(torch.__version__.split('.')[-3]) > 0) + + dataset = config.get_dataset('test', cfg, return_idx=True) + test_loader = torch.utils.data.DataLoader( + dataset, batch_size=1, num_workers=0, shuffle=False) + + model = Encode2Points(cfg).to(device) + + # load model + try: + if is_url(cfg['test']['model_file']): + state_dict = load_url(cfg['test']['model_file']) + elif cfg['generation'].get('iter', 0)!=0: + state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% cfg['generation']['iter'])) + generation_dir += '_%04d'%cfg['generation']['iter'] + elif args.iter is not None: + state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% args.iter)) + else: + state_dict = torch.load(os.path.join(out_dir, 'model_best.pt')) + + load_model_manual(state_dict['state_dict'], model) + + except: + print('Model loading error. Exiting.') + exit() + + + # Generator + generator = config.get_generator(model, cfg, device=device) + + # Determine what to generate + generate_mesh = cfg['generation']['generate_mesh'] + generate_pointcloud = cfg['generation']['generate_pointcloud'] + + # Statistics + time_dicts = [] + + # Generate + model.eval() + dpsr = DPSR(res=(cfg['generation']['psr_resolution'], + cfg['generation']['psr_resolution'], + cfg['generation']['psr_resolution']), + sig= cfg['generation']['psr_sigma']).to(device) + + + + # Count how many models already created + model_counter = defaultdict(int) + + print('Generating...') + for it, data in enumerate(tqdm(test_loader)): + + # Output folders + mesh_dir = os.path.join(generation_dir, 'meshes') + in_dir = os.path.join(generation_dir, 'input') + pointcloud_dir = os.path.join(generation_dir, 'pointcloud') + generation_vis_dir = os.path.join(generation_dir, 'vis', ) + + # Get index etc.
+ idx = data['idx'].item() + + try: + model_dict = dataset.get_model_dict(idx) + except AttributeError: + model_dict = {'model': str(idx), 'category': 'n/a'} + + modelname = model_dict['model'] + category_id = model_dict['category'] + + try: + category_name = dataset.metadata[category_id].get('name', 'n/a') + except AttributeError: + category_name = 'n/a' + + if category_id != 'n/a': + mesh_dir = os.path.join(mesh_dir, str(category_id)) + pointcloud_dir = os.path.join(pointcloud_dir, str(category_id)) + in_dir = os.path.join(in_dir, str(category_id)) + + folder_name = str(category_id) + if category_name != 'n/a': + folder_name = str(folder_name) + '_' + category_name.split(',')[0] + + generation_vis_dir = os.path.join(generation_vis_dir, folder_name) + + # Create directories if necessary + if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir): + os.makedirs(generation_vis_dir) + + if generate_mesh and not os.path.exists(mesh_dir): + os.makedirs(mesh_dir) + + if generate_pointcloud and not os.path.exists(pointcloud_dir): + os.makedirs(pointcloud_dir) + + if not os.path.exists(in_dir): + os.makedirs(in_dir) + + # Timing dict + time_dict = { + 'idx': idx, + 'class id': category_id, + 'class name': category_name, + 'modelname':modelname, + } + time_dicts.append(time_dict) + + # Generate outputs + out_file_dict = {} + + if generate_mesh: + #! deploy the generator to a separate class + out = generator.generate_mesh(data) + + v, f, points, normals, stats_dict = out + time_dict.update(stats_dict) + + # Write output + mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname) + export_mesh(mesh_out_file, scale2onet(v), f) + out_file_dict['mesh'] = mesh_out_file + + if generate_pointcloud: + pointcloud_out_file = os.path.join( + pointcloud_dir, '%s.ply' % modelname) + export_pointcloud(pointcloud_out_file, scale2onet(points), normals) + out_file_dict['pointcloud'] = pointcloud_out_file + + if cfg['generation']['copy_input']: + inputs_path = os.path.join(in_dir, '%s.ply' % modelname) + p = data.get('inputs').to(device) + export_pointcloud(inputs_path, scale2onet(p)) + out_file_dict['in'] = inputs_path + + # Copy to visualization directory for first vis_n_output samples + c_it = model_counter[category_id] + if c_it < vis_n_outputs: + # Save output files + img_name = '%02d.off' % c_it + for k, filepath in out_file_dict.items(): + ext = os.path.splitext(filepath)[1] + out_file = os.path.join(generation_vis_dir, '%02d_%s%s' + % (c_it, k, ext)) + shutil.copyfile(filepath, out_file) + + # Also generate oracle meshes + if cfg['generation']['exp_oracle']: + points_gt = data.get('gt_points').to(device) + normals_gt = data.get('gt_points.normals').to(device) + psr_gt = dpsr(points_gt, normals_gt) + v, f, _ = mc_from_psr(psr_gt, + zero_level=cfg['data']['zero_level']) + out_file = os.path.join(generation_vis_dir, '%02d_%s%s' + % (c_it, 'mesh_oracle', '.off')) + export_mesh(out_file, scale2onet(v), f) + + model_counter[category_id] += 1 + + + # Create pandas dataframe and save + time_df = pd.DataFrame(time_dicts) + time_df.set_index(['idx'], inplace=True) + time_df.to_pickle(out_time_file) + + # Create pickle files with main statistics + time_df_class = time_df.groupby(by=['class name']).mean() + time_df_class.loc['mean'] = time_df_class.mean() + time_df_class.to_pickle(out_time_file_class) + + # Print results + print('Timings [s]:') + print(time_df_class) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/media/results_large_noise.gif b/media/results_large_noise.gif new 
file mode 100644 index 0000000..4feabda Binary files /dev/null and b/media/results_large_noise.gif differ diff --git a/media/results_outliers.gif b/media/results_outliers.gif new file mode 100644 index 0000000..4b3abc0 Binary files /dev/null and b/media/results_outliers.gif differ diff --git a/optim.py b/optim.py new file mode 100644 index 0000000..fa552b0 --- /dev/null +++ b/optim.py @@ -0,0 +1,315 @@ +import torch +import trimesh +import shutil, argparse, time, os, glob + +import numpy as np; np.set_printoptions(precision=4) +import open3d as o3d +from torch.utils.tensorboard import SummaryWriter +from torchvision.utils import save_image +from torchvision.io import write_video + +from src.optimization import Trainer +from src.utils import load_config, update_config, initialize_logger, \ + get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\ + update_optimizer, export_pointcloud +from skimage import measure +from plyfile import PlyData +from pytorch3d.ops import sample_points_from_meshes +from pytorch3d.io import load_objs_as_meshes +from pytorch3d.structures import Meshes + + +def main(): + parser = argparse.ArgumentParser(description='MNIST toy experiment') + parser.add_argument('config', type=str, help='Path to config file.') + parser.add_argument('--no_cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1457, metavar='S', + help='random seed') + + args, unknown = parser.parse_known_args() + cfg = load_config(args.config, 'configs/default.yaml') + cfg = update_config(cfg, unknown) + use_cuda = not args.no_cuda and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + data_type = cfg['data']['data_type'] + data_class = cfg['data']['class'] + + print(cfg['train']['out_dir']) + + # PYTORCH VERSION > 1.0.0 + assert(float(torch.__version__.split('.')[-3]) > 0) + + # boiler-plate + if cfg['train']['timestamp']: + cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S") + logger = initialize_logger(cfg) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + shutil.copyfile(args.config, + os.path.join(cfg['train']['out_dir'], 'config.yaml')) + + # tensorboardX writer + tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log") + if not os.path.exists(tblogdir): + os.makedirs(tblogdir) + writer = SummaryWriter(log_dir=tblogdir) + + # initialize o3d visualizer + vis = None + if cfg['train']['o3d_show']: + vis = o3d.visualization.Visualizer() + vis.create_window(width=cfg['train']['o3d_window_size'], + height=cfg['train']['o3d_window_size']) + + # initialize dataset + if data_type == 'point': + if cfg['data']['object_id'] != -1: + data_paths = sorted(glob.glob(cfg['data']['data_path'])) + data_path = data_paths[cfg['data']['object_id']] + print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths))) + else: + data_path = cfg['data']['data_path'] + print('Data loaded') + ext = data_path.split('.')[-1] + if ext == 'obj': # have GT mesh + mesh = load_objs_as_meshes([data_path], device=device) + # scale the mesh into unit cube + verts = mesh.verts_packed() + N = verts.shape[0] + center = verts.mean(0) + mesh.offset_verts_(-center.expand(N, 3)) + scale = max((verts - center).abs().max(0)[0]) + mesh.scale_verts_((1.0 / float(scale))) + # important for our DPSR to have the range in [0, 1), not reaching 1 + mesh.scale_verts_(0.9) + + target_pts, target_normals = sample_points_from_meshes(mesh, + 
num_samples=200000, return_normals=True) + elif ext == 'ply': # only have the point cloud + plydata = PlyData.read(data_path) + vertices = np.stack([plydata['vertex']['x'], + plydata['vertex']['y'], + plydata['vertex']['z']], axis=1) + normals = np.stack([plydata['vertex']['nx'], + plydata['vertex']['ny'], + plydata['vertex']['nz']], axis=1) + N = vertices.shape[0] + center = vertices.mean(0) + scale = np.max(np.max(np.abs(vertices - center), axis=0)) + vertices -= center + vertices /= scale + vertices *= 0.9 + + target_pts = torch.tensor(vertices, device=device)[None].float() + target_normals = torch.tensor(normals, device=device)[None].float() + mesh = None # no GT mesh + + if not torch.is_tensor(center): + center = torch.from_numpy(center) + if not torch.is_tensor(scale): + scale = torch.from_numpy(np.array([scale])) + + data = {'target_points': target_pts, + 'target_normals': target_normals, # normals are never used + 'gt_mesh': mesh} + + else: + raise NotImplementedError + + # save the input point cloud + if 'target_points' in data.keys(): + outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply') + if 'target_normals' in data.keys(): + export_pointcloud(outdir_pcl, data['target_points'], data['target_normals']) + else: + export_pointcloud(outdir_pcl, data['target_points']) + + # save oracle PSR mesh (mesh from our PSR using GT point+normals) + if data.get('gt_mesh') is not None: + gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0) + pts_gt, norms_gt = sample_points_from_meshes(data['gt_mesh'], + num_samples=500000, return_normals=True) + pts_gt = (pts_gt + 1) / 2 + from src.dpsr import DPSR + dpsr_tmp = DPSR(res=(cfg['model']['grid_res'], + cfg['model']['grid_res'], + cfg['model']['grid_res']), + sig=cfg['model']['psr_sigma']).to(device) + target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device) + target = torch.tanh(target) + s = target.shape[-1] # size of psr_grid + psr_grid_numpy = target.squeeze().detach().cpu().numpy() + verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy) + verts = verts / s * 2. - 1 # [-1, 1] + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(verts) + mesh.triangles = o3d.utility.Vector3iVector(faces) + outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply') + o3d.io.write_triangle_mesh(outdir_mesh, mesh) + + # initialize the source point cloud given an input mesh + if 'input_mesh' in cfg['train'].keys() and \ + os.path.isfile(cfg['train']['input_mesh']): + if cfg['train']['input_mesh'].split('/')[-2] == 'mesh': + mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh']) + verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device) + faces = torch.from_numpy(mesh_tmp.faces[None]).to(device) + mesh = Meshes(verts=verts, faces=faces) + points, normals = sample_points_from_meshes(mesh, + num_samples=cfg['data']['num_points'], return_normals=True) + # mesh is saved in the original scale of the gt + points -= center.float().to(device) + points /= scale.float().to(device) + points *= 0.9 + # make sure the points are within the range of [0, 1) + points = points / 2. + 0.5 + else: + # directly initialize from a point cloud + pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh']) + points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device) + normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device) + points -= center.float().to(device) + points /= scale.float().to(device) + points *= 0.9 + points = points / 2. + 0.5 + else: #! 
initialize our source point cloud from a sphere + sphere_radius = cfg['model']['sphere_radius'] + sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, + count=[256,256]) + points, idx = sphere_mesh.sample(cfg['data']['num_points'], + return_index=True) + points += 0.5 # make sure the points are within the range of [0, 1) + normals = sphere_mesh.face_normals[idx] + points = torch.from_numpy(points).unsqueeze(0).to(device) + normals = torch.from_numpy(normals).unsqueeze(0).to(device) + + + points = torch.log(points/(1-points)) # inverse sigmoid + inputs = torch.cat([points, normals], axis=-1).float() + inputs.requires_grad = True + + model = None # no network + + # initialize optimizer + cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl'] + print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial']) + if 'schedule' in cfg['train']: + lr_schedules = get_learning_rate_schedules(cfg['train']['schedule']) + else: + lr_schedules = None + + optimizer = update_optimizer(inputs, cfg, + epoch=0, model=model, schedule=lr_schedules) + + try: + # load model + state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt')) + if ('pcl' in state_dict.keys()) & (state_dict['pcl'] is not None): + inputs = state_dict['pcl'].to(device) + inputs.requires_grad = True + + optimizer = update_optimizer(inputs, cfg, + epoch=state_dict.get('epoch'), schedule=lr_schedules) + + out = "Load model from epoch %d" % state_dict.get('epoch', 0) + print(out) + logger.info(out) + except: + state_dict = dict() + + start_epoch = state_dict.get('epoch', -1) + + trainer = Trainer(cfg, optimizer, device=device) + runtime = {} + runtime['all'] = AverageMeter() + + # training loop + for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1): + + # schedule the learning rate + if (epoch>0) & (lr_schedules is not None): + if (epoch % lr_schedules[0].interval == 0): + adjust_learning_rate(lr_schedules, optimizer, epoch) + if len(lr_schedules) >1: + print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch, + lr_schedules[0].get_learning_rate(epoch), + lr_schedules[1].get_learning_rate(epoch))) + else: + print('[epoch {}] adjust pcl_lr to: {}'.format(epoch, + lr_schedules[0].get_learning_rate(epoch))) + + start = time.time() + loss, loss_each = trainer.train_step(data, inputs, model, epoch) + runtime['all'].update(time.time() - start) + + if epoch % cfg['train']['print_every'] == 0: + log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss) + if loss_each is not None: + for k, l in loss_each.items(): + if l.item() != 0.: + log_text += (' loss_%s=%.5f') % (k, l.item()) + + log_text += (' time=%.3f / %.3f') % (runtime['all'].val, + runtime['all'].sum) + logger.info(log_text) + print(log_text) + + # visualize point clouds and meshes + if (epoch % cfg['train']['visualize_every'] == 0) & (vis is not None): + trainer.visualize(data, inputs, model, epoch, o3d_vis=vis) + + # save outputs + if epoch % cfg['train']['save_every'] == 0: + trainer.save_mesh_pointclouds(inputs, epoch, + center.cpu().numpy(), + scale.cpu().numpy()*(1/0.9)) + + # save checkpoints + if (epoch > 0) & (epoch % cfg['train']['checkpoint_every'] == 0): + state = {'epoch': epoch} + pcl = None + if isinstance(inputs, torch.Tensor): + state['pcl'] = inputs.detach().cpu() + + torch.save(state, os.path.join(cfg['train']['dir_model'], + '%04d' % epoch + '.pt')) + print("Save new model at epoch %d" % epoch) + logger.info("Save new model at epoch %d" % epoch) + torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt')) + 
+ # resample and gradually add new points to the source pcl + if (epoch > 0) & \ + (cfg['train']['resample_every']!=0) & \ + (epoch % cfg['train']['resample_every'] == 0) & \ + (epoch < cfg['train']['total_epochs']): + inputs = trainer.point_resampling(inputs) + optimizer = update_optimizer(inputs, cfg, + epoch=epoch, model=model, schedule=lr_schedules) + trainer = Trainer(cfg, optimizer, device=device) + + # visualize the Open3D outputs + if cfg['train']['o3d_show']: + out_video_dir = os.path.join(cfg['train']['out_dir'], + 'vis/o3d/video.mp4') + if os.path.isfile(out_video_dir): + os.system('rm {}'.format(out_video_dir)) + os.system('ffmpeg -framerate 30 \ + -start_number 0 \ + -i {}/vis/o3d/%04d.jpg \ + -pix_fmt yuv420p \ + -crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir)) + out_video_dir = os.path.join(cfg['train']['out_dir'], + 'vis/o3d/video_pcd.mp4') + if os.path.isfile(out_video_dir): + os.system('rm {}'.format(out_video_dir)) + os.system('ffmpeg -framerate 30 \ + -start_number 0 \ + -i {}/vis/o3d/%04d_pcd.jpg \ + -pix_fmt yuv420p \ + -crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir)) + print('Video saved.') + +if __name__ == '__main__': + main() diff --git a/optim_hierarchy.py b/optim_hierarchy.py new file mode 100644 index 0000000..23347b9 --- /dev/null +++ b/optim_hierarchy.py @@ -0,0 +1,69 @@ +import sys, os +import argparse +from src.utils import load_config +import subprocess +os.environ['MKL_THREADING_LAYER'] = 'GNU' + +def main(): + + parser = argparse.ArgumentParser(description='MNIST toy experiment') + parser.add_argument('config', type=str, help='Path to config file.') + parser.add_argument('--start_res', type=int, default=-1, help='Resolution to start with.') + parser.add_argument('--object_id', type=int, default=-1, help='Object index.') + + args, unknown = parser.parse_known_args() + cfg = load_config(args.config, 'configs/default.yaml') + + resolutions=[32, 64, 128, 256] + iterations=[1000, 1000, 1000, 200] + lrs=[2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)] # reduce lr + for idx,(res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)): + + if res > cfg['model']['grid_res']: + continue + + psr_sigma= 2 if res<=128 else 3 + + if res > 128: + psr_sigma = 5 if 'thingi_noisy' in args.config else 3 + + if args.object_id != -1: + out_dir = os.path.join(cfg['train']['out_dir'], 'object_%02d'%args.object_id, 'res_%d'%res) + else: + out_dir = os.path.join(cfg['train']['out_dir'], 'res_%d'%res) + + # sample from mesh when resampling is enabled, otherwise reuse the pointcloud + init_shape='mesh' if cfg['train']['resample_every']>0 else 'pointcloud' + + + if args.object_id != -1: + input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'], + 'object_%02d'%args.object_id, 'res_%d' % (resolutions[idx-1]), + 'vis', init_shape, '%04d.ply' % (iterations[idx-1])) + else: + input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'], + 'res_%d' % (resolutions[idx-1]), + 'vis', init_shape, '%04d.ply' % (iterations[idx-1])) + + + cmd = 'export MKL_SERVICE_FORCE_INTEL=1 && ' + cmd += "python optim.py %s --model:grid_res %d --model:psr_sigma %d \ + --train:input_mesh %s --train:total_epochs %d \ + --train:out_dir %s --train:lr_pcl %f \ + --data:object_id %d" % ( + args.config, + res, + psr_sigma, + input_mesh, + iteration, + out_dir, + lr, + args.object_id) + print(cmd) + os.system(cmd) + +if __name__=="__main__": + main() diff --git a/scripts/download_demo_data.sh b/scripts/download_demo_data.sh new file mode 100644 index 0000000..8f617b2
--- /dev/null +++ b/scripts/download_demo_data.sh @@ -0,0 +1,8 @@ +#!/bin/bash +mkdir -p data +cd data +echo "Start downloading ..." +wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/demo.zip +unzip demo.zip +rm demo.zip +echo "Done!" \ No newline at end of file diff --git a/scripts/download_optim_data.sh b/scripts/download_optim_data.sh new file mode 100644 index 0000000..43dbc47 --- /dev/null +++ b/scripts/download_optim_data.sh @@ -0,0 +1,8 @@ +#!/bin/bash +mkdir -p data +cd data +echo "Start downloading data for optimization-based setting (~200 MB)" +wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/optim_data.zip +unzip optim_data.zip +rm optim_data.zip +echo "Done!" \ No newline at end of file diff --git a/scripts/download_shapenet.sh b/scripts/download_shapenet.sh new file mode 100644 index 0000000..6d6e90c --- /dev/null +++ b/scripts/download_shapenet.sh @@ -0,0 +1,8 @@ +#!/bin/bash +mkdir -p data +cd data +echo "Start downloading preprocessed ShapeNet data (~220G)" +wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/shapenet_psr.zip +unzip shapenet_psr.zip +rm shapenet_psr.zip +echo "Done!" \ No newline at end of file diff --git a/scripts/process_shapenet.py b/scripts/process_shapenet.py new file mode 100644 index 0000000..cde672b --- /dev/null +++ b/scripts/process_shapenet.py @@ -0,0 +1,101 @@ +import os +import torch +import time +import multiprocessing +import numpy as np +from tqdm import tqdm +from src.dpsr import DPSR + +data_path = 'data/ShapeNet' # path for ShapeNet from ONet +base = 'data' # output base directory +dataset_name = 'shapenet_psr' +multiprocess = True +njobs = 8 +save_pointcloud = True +save_psr_field = True +resolution = 128 +zero_level = 0.0 +num_points = 100000 +padding = 1.2 + +dpsr = DPSR(res=(resolution, resolution, resolution), sig=0) + +def process_one(obj): + + obj_name = obj.split('/')[-1] + c = obj.split('/')[-2] + + # create new for the current object + out_path_cur = os.path.join(base, dataset_name, c) + out_path_cur_obj = os.path.join(out_path_cur, obj_name) + os.makedirs(out_path_cur_obj, exist_ok=True) + + gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz') + data = np.load(gt_path) + points = data['points'] + normals = data['normals'] + + # normalize the point to [0, 1) + points = points / padding + 0.5 + # to scale back during inference, we should: + #! 
p = (p - 0.5) * padding + + if save_pointcloud: + outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz') + # np.savez(outdir, points=points, normals=normals) + np.savez(outdir, points=data['points'], normals=data['normals']) + # return + + if save_psr_field: + psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None], + torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16) + + outdir = os.path.join(out_path_cur_obj, 'psr.npz') + np.savez(outdir, psr=psr_gt) + + +def main(c): + + print('---------------------------------------') + print('Processing {} {}'.format(c, split)) + print('---------------------------------------') + + for split in ['train', 'val', 'test']: + fname = os.path.join(data_path, c, split+'.lst') + with open(fname, 'r') as f: + obj_list = f.read().splitlines() + + obj_list = [c+'/'+s for s in obj_list] + + if multiprocess: + # multiprocessing.set_start_method('spawn', force=True) + pool = multiprocessing.Pool(njobs) + try: + for _ in tqdm(pool.imap_unordered(process_one, obj_list), total=len(obj_list)): + pass + # pool.map_async(process_one, obj_list).get() + except KeyboardInterrupt: + # Allow ^C to interrupt from any thread. + exit() + pool.close() + else: + for obj in tqdm(obj_list): + process_one(obj) + + print('Done Processing {} {}!'.format(c, split)) + + +if __name__ == "__main__": + + classes = ['02691156', '02828884', '02933112', + '02958343', '03211117', '03001627', + '03636649', '03691459', '04090263', + '04256520', '04379243', '04401088', '04530566'] + + + t_start = time.time() + for c in classes: + main(c) + + t_end = time.time() + print('Total processing time: ', t_end - t_start) diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..2c318e3 --- /dev/null +++ b/src/config.py @@ -0,0 +1,146 @@ +import yaml +from torchvision import transforms +from src import data, generation +from src.dpsr import DPSR +from ipdb import set_trace as st + + +# Generator for final mesh extraction +def get_generator(model, cfg, device, **kwargs): + ''' Returns the generator object. + + Args: + model (nn.Module): Occupancy Network model + cfg (dict): imported yaml config + device (device): pytorch device + ''' + + if cfg['generation']['psr_resolution'] == 0: + psr_res = cfg['model']['grid_res'] + psr_sigma = cfg['model']['psr_sigma'] + else: + psr_res = cfg['generation']['psr_resolution'] + psr_sigma = cfg['generation']['psr_sigma'] + + dpsr = DPSR(res=(psr_res, psr_res, psr_res), + sig= psr_sigma).to(device) + + + generator = generation.Generator3D( + model, + device=device, + threshold=cfg['data']['zero_level'], + sample=cfg['generation']['use_sampling'], + input_type = cfg['data']['input_type'], + padding=cfg['data']['padding'], + dpsr=dpsr, + psr_tanh=cfg['model']['psr_tanh'] + ) + return generator + +# Datasets +def get_dataset(mode, cfg, return_idx=False): + ''' Returns the dataset. 
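The preprocessing above commits to a fixed coordinate convention: points are divided by `padding = 1.2` and shifted by 0.5 so the DPSR grid sees coordinates in [0, 1), and inference has to undo this with `(p - 0.5) * padding`. A small stand-alone check of that round trip and of the ground-truth PSR field generation (the random data is illustrative):

```python
import numpy as np
import torch
from src.dpsr import DPSR

padding = 1.2
p = np.random.uniform(-0.55, 0.55, (10000, 3)).astype(np.float32)   # illustrative ONet-style points
n = np.random.randn(10000, 3).astype(np.float32)
n /= np.linalg.norm(n, axis=-1, keepdims=True)

q = p / padding + 0.5                                      # normalise to [0, 1) as in process_one()
assert np.allclose((q - 0.5) * padding, p, atol=1e-5)      # inverse used at inference time

dpsr = DPSR(res=(128, 128, 128), sig=0)                    # same settings as the script
psr_gt = dpsr(torch.from_numpy(q)[None], torch.from_numpy(n)[None])
print(psr_gt.shape)                                        # (1, 128, 128, 128); the script stores it as float16
```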
+ + Args: + model (nn.Module): the model which is used + cfg (dict): config dictionary + return_idx (bool): whether to include an ID field + ''' + dataset_type = cfg['data']['dataset'] + dataset_folder = cfg['data']['path'] + categories = cfg['data']['class'] + + # Get split + splits = { + 'train': cfg['data']['train_split'], + 'val': cfg['data']['val_split'], + 'test': cfg['data']['test_split'], + 'vis': cfg['data']['val_split'], + } + + split = splits[mode] + + # Create dataset + if dataset_type == 'Shapes3D': + fields = get_data_fields(mode, cfg) + # Input fields + inputs_field = get_inputs_field(mode, cfg) + if inputs_field is not None: + fields['inputs'] = inputs_field + + if return_idx: + fields['idx'] = data.IndexField() + + dataset = data.Shapes3dDataset( + dataset_folder, fields, + split=split, + categories=categories, + cfg = cfg + ) + else: + raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset']) + + return dataset + + +def get_inputs_field(mode, cfg): + ''' Returns the inputs fields. + + Args: + mode (str): the mode which is used + cfg (dict): config dictionary + ''' + input_type = cfg['data']['input_type'] + + if input_type is None: + inputs_field = None + elif input_type == 'pointcloud': + noise_level = cfg['data']['pointcloud_noise'] + if cfg['data']['pointcloud_outlier_ratio']>0: + transform = transforms.Compose([ + data.SubsamplePointcloud(cfg['data']['pointcloud_n']), + data.PointcloudNoise(noise_level), + data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio']) + ]) + else: + transform = transforms.Compose([ + data.SubsamplePointcloud(cfg['data']['pointcloud_n']), + data.PointcloudNoise(noise_level) + ]) + + data_type = cfg['data']['data_type'] + inputs_field = data.PointCloudField( + cfg['data']['pointcloud_file'], data_type, transform, + multi_files= cfg['data']['multi_files'] + ) + else: + raise ValueError( + 'Invalid input type (%s)' % input_type) + return inputs_field + +def get_data_fields(mode, cfg): + ''' Returns the data fields. 
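For context, these config helpers are what the training and generation scripts use to build their data loaders. A hedged usage sketch (the config path and the `train.batch_size` key are assumptions, not shown in this diff):

```python
import torch
from src import config
from src.data import collate_remove_none, worker_init_fn
from src.utils import load_config

cfg = load_config('configs/learning_based/demo_large_noise.yaml', 'configs/default.yaml')
dataset = config.get_dataset('train', cfg)                  # Shapes3dDataset with the fields above
loader = torch.utils.data.DataLoader(
    dataset, batch_size=cfg['train'].get('batch_size', 32), num_workers=4, shuffle=True,
    collate_fn=collate_remove_none,                         # drops samples whose fields failed to load
    worker_init_fn=worker_init_fn)                          # reseeds numpy per worker
batch = next(iter(loader))
print(batch['inputs'].shape, batch['gt_points'].shape)
```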
+ + Args: + mode (str): the mode which is used + cfg (dict): imported yaml config + ''' + data_type = cfg['data']['data_type'] + fields = {} + + if (mode in ('val', 'test')): + transform = data.SubsamplePointcloud(100000) + else: + transform = data.SubsamplePointcloud(cfg['data']['num_gt_points']) + + data_name = cfg['data']['pointcloud_file'] + fields['gt_points'] = data.PointCloudField(data_name, + transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files']) + if data_type == 'psr_full': + if mode != 'test': + fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files']) + else: + raise ValueError('Invalid data type (%s)' % data_type) + + return fields \ No newline at end of file diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..a7f46c1 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,25 @@ + +from src.data.core import ( + Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together +) +from src.data.fields import ( + IndexField, PointCloudField, FullPSRField +) +from src.data.transforms import ( + PointcloudNoise, SubsamplePointcloud, + PointcloudOutliers, +) +__all__ = [ + # Core + Shapes3dDataset, + collate_remove_none, + worker_init_fn, + # Fields + IndexField, + PointCloudField, + FullPSRField, + # Transforms + PointcloudNoise, + SubsamplePointcloud, + PointcloudOutliers, +] diff --git a/src/data/core.py b/src/data/core.py new file mode 100644 index 0000000..142747c --- /dev/null +++ b/src/data/core.py @@ -0,0 +1,237 @@ +import os +import logging +from torch.utils import data +from pdb import set_trace as st +import numpy as np +import yaml + + +logger = logging.getLogger(__name__) + + +# Fields +class Field(object): + ''' Data fields class. + ''' + + def load(self, data_path, idx, category): + ''' Loads a data point. + + Args: + data_path (str): path to data file + idx (int): index of data point + category (int): index of category + ''' + raise NotImplementedError + + def check_complete(self, files): + ''' Checks if set is complete. + + Args: + files: files + ''' + raise NotImplementedError + + +class Shapes3dDataset(data.Dataset): + ''' 3D Shapes dataset class. + ''' + + def __init__(self, dataset_folder, fields, split=None, + categories=None, no_except=True, transform=None, cfg=None): + ''' Initialization of the the 3D shape dataset. 
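The `Field` base class above is the extension point for new per-model data; a minimal, purely hypothetical subclass only needs to implement `load` and `check_complete`:

```python
import os
import numpy as np
from src.data.core import Field

class LabelField(Field):
    ''' Hypothetical example field: loads a per-model label stored as label.npy. '''
    def load(self, model_path, idx, category):
        return np.load(os.path.join(model_path, 'label.npy'))

    def check_complete(self, files):
        return 'label.npy' in files
```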
+ + Args: + dataset_folder (str): dataset folder + fields (dict): dictionary of fields + split (str): which split is used + categories (list): list of categories to use + no_except (bool): no exception + transform (callable): transformation applied to data points + cfg (yaml): config file + ''' + # Attributes + self.dataset_folder = dataset_folder + self.fields = fields + self.no_except = no_except + self.transform = transform + self.cfg = cfg + + # If categories is None, use all subfolders + if categories is None: + categories = os.listdir(dataset_folder) + categories = [c for c in categories + if os.path.isdir(os.path.join(dataset_folder, c))] + + # Read metadata file + metadata_file = os.path.join(dataset_folder, 'metadata.yaml') + + if os.path.exists(metadata_file): + with open(metadata_file, 'r') as f: + self.metadata = yaml.load(f, Loader=yaml.Loader) + else: + self.metadata = { + c: {'id': c, 'name': 'n/a'} for c in categories + } + + # Set index + for c_idx, c in enumerate(categories): + self.metadata[c]['idx'] = c_idx + + # Get all models + self.models = [] + for c_idx, c in enumerate(categories): + subpath = os.path.join(dataset_folder, c) + if not os.path.isdir(subpath): + logger.warning('Category %s does not exist in dataset.' % c) + + if split is None: + self.models += [ + {'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ] + ] + + else: + split_file = os.path.join(subpath, split + '.lst') + with open(split_file, 'r') as f: + models_c = f.read().split('\n') + + if '' in models_c: + models_c.remove('') + + self.models += [ + {'category': c, 'model': m} + for m in models_c + ] + + # precompute + self.split = split + + def __len__(self): + ''' Returns the length of the dataset. + ''' + return len(self.models) + + def __getitem__(self, idx): + ''' Returns an item of the dataset. + + Args: + idx (int): ID of data point + ''' + + category = self.models[idx]['category'] + model = self.models[idx]['model'] + c_idx = self.metadata[category]['idx'] + + model_path = os.path.join(self.dataset_folder, category, model) + data = {} + + info = c_idx + + if self.cfg['data']['multi_files'] is not None: + idx = np.random.randint(self.cfg['data']['multi_files']) + if self.split != 'train': + idx = 0 + + for field_name, field in self.fields.items(): + try: + field_data = field.load(model_path, idx, info) + except Exception: + if self.no_except: + logger.warn( + 'Error occured when loading field %s of model %s' + % (field_name, model) + ) + return None + else: + raise + + if isinstance(field_data, dict): + for k, v in field_data.items(): + if k is None: + data[field_name] = v + else: + data['%s.%s' % (field_name, k)] = v + else: + data[field_name] = field_data + + if self.transform is not None: + data = self.transform(data) + + return data + + + def get_model_dict(self, idx): + return self.models[idx] + + def test_model_complete(self, category, model): + ''' Tests if model is complete. + + Args: + model (str): modelname + ''' + model_path = os.path.join(self.dataset_folder, category, model) + files = os.listdir(model_path) + for field_name, field in self.fields.items(): + if not field.check_complete(files): + logger.warn('Field "%s" is incomplete: %s' + % (field_name, model_path)) + return False + + return True + + +def collate_remove_none(batch): + ''' Collater that puts each data field into a tensor with outer dimension + batch size. 
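Worth noting for downstream code: `__getitem__` flattens each field's dict output into dotted keys, so a `PointCloudField` registered under `'inputs'` that returns `{None: points, 'normals': normals}` shows up in the batch as `data['inputs']` and `data['inputs.normals']`. A tiny illustration of that convention:

```python
# Illustration of the key-flattening convention used in Shapes3dDataset.__getitem__.
field_name = 'inputs'
field_data = {None: 'points array', 'normals': 'normals array'}   # what PointCloudField.load returns

data = {}
for k, v in field_data.items():
    if k is None:
        data[field_name] = v
    else:
        data['%s.%s' % (field_name, k)] = v

print(sorted(data))    # ['inputs', 'inputs.normals']
```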
+ + Args: + batch: batch + ''' + + batch = list(filter(lambda x: x is not None, batch)) + return data.dataloader.default_collate(batch) + +def collate_stack_together(batch): + ''' Collater that puts each data field into a tensor with outer dimension + batch size. + + Args: + batch: batch + ''' + + batch = list(filter(lambda x: x is not None, batch)) + keys = batch[0].keys() + concat = {} + if len(batch)>1: + for key in keys: + key_val = [item[key] for item in batch] + concat[key] = np.concatenate(key_val, axis=0) + if key == 'inputs': + n_pts = [item[key].shape[0] for item in batch] + + concat['batch_ind'] = np.concatenate( + [i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0) + + return data.dataloader.default_collate([concat]) + else: + n_pts = batch[0]['inputs'].shape[0] + batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int) + return data.dataloader.default_collate(batch) + + +def worker_init_fn(worker_id): + ''' Worker init function to ensure true randomness. + ''' + def set_num_threads(nt): + try: + import mkl; mkl.set_num_threads(nt) + except: + pass + torch.set_num_threads(1) + os.environ['IPC_ENABLE']='1' + for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']: + os.environ[o] = str(nt) + + random_data = os.urandom(4) + base_seed = int.from_bytes(random_data, byteorder="big") + np.random.seed(base_seed + worker_id) diff --git a/src/data/fields.py b/src/data/fields.py new file mode 100644 index 0000000..7f78435 --- /dev/null +++ b/src/data/fields.py @@ -0,0 +1,118 @@ +import os +import glob +import time +import random +from PIL import Image +import numpy as np +import trimesh +from src.data.core import Field +from pdb import set_trace as st + + +class IndexField(Field): + ''' Basic index field.''' + def load(self, model_path, idx, category): + ''' Loads the index field. + + Args: + model_path (str): path to model + idx (int): ID of data point + category (int): index of category + ''' + return idx + + def check_complete(self, files): + ''' Check if field is complete. + + Args: + files: files + ''' + return True + +class FullPSRField(Field): + def __init__(self, transform=None, multi_files=None): + self.transform = transform + # self.unpackbits = unpackbits + self.multi_files = multi_files + + def load(self, model_path, idx, category): + + # try: + # t0 = time.time() + if self.multi_files is not None: + psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx)) + else: + psr_path = os.path.join(model_path, 'psr.npz') + psr_dict = np.load(psr_path) + # t1 = time.time() + psr = psr_dict['psr'] + psr = psr.astype(np.float32) + # t2 = time.time() + # print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0)) + data = {None: psr} + + if self.transform is not None: + data = self.transform(data) + + return data + +class PointCloudField(Field): + ''' Point cloud field. + + It provides the field used for point cloud data. These are the points + randomly sampled on the mesh. 
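`collate_stack_together` concatenates variable-sized point clouds from a batch and records which sample each point came from in `batch_ind`; a small stand-alone check of that bookkeeping (as an aside, the inner `set_num_threads` helper in `worker_init_fn` above is defined but never called, and `core.py` does not import `torch`, so that branch looks vestigial):

```python
import numpy as np

# Mirrors the batch_ind construction in collate_stack_together.
n_pts = [1200, 800, 1500]                                   # points per sample in the batch
batch_ind = np.concatenate(
    [i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)

assert batch_ind.shape[0] == sum(n_pts)
assert (np.bincount(batch_ind) == np.array(n_pts)).all()    # one contiguous index block per sample
```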
+ + Args: + file_name (str): file name + transform (list): list of transformations applied to data points + multi_files (callable): number of files + ''' + def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2): + self.file_name = file_name + self.data_type = data_type # to make sure the range of input is correct + self.transform = transform + self.multi_files = multi_files + self.padding = padding + self.scale = scale + + def load(self, model_path, idx, category): + ''' Loads the data point. + + Args: + model_path (str): path to model + idx (int): ID of data point + category (int): index of category + ''' + if self.multi_files is None: + file_path = os.path.join(model_path, self.file_name) + else: + # num = np.random.randint(self.multi_files) + # file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num)) + file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx)) + + pointcloud_dict = np.load(file_path) + + points = pointcloud_dict['points'].astype(np.float32) + normals = pointcloud_dict['normals'].astype(np.float32) + + data = { + None: points, + 'normals': normals, + } + if self.transform is not None: + data = self.transform(data) + + if self.data_type == 'psr_full': + # scale the point cloud to the range of (0, 1) + data[None] = data[None] / self.scale + 0.5 + + return data + + def check_complete(self, files): + ''' Check if field is complete. + + Args: + files: files + ''' + complete = (self.file_name in files) + return complete diff --git a/src/data/transforms.py b/src/data/transforms.py new file mode 100644 index 0000000..8909594 --- /dev/null +++ b/src/data/transforms.py @@ -0,0 +1,86 @@ +import numpy as np + + +# Transforms +class PointcloudNoise(object): + ''' Point cloud noise transformation class. + + It adds noise to point cloud data. + + Args: + stddev (int): standard deviation + ''' + + def __init__(self, stddev): + self.stddev = stddev + + def __call__(self, data): + ''' Calls the transformation. + + Args: + data (dictionary): data dictionary + ''' + data_out = data.copy() + points = data[None] + noise = self.stddev * np.random.randn(*points.shape) + noise = noise.astype(np.float32) + data_out[None] = points + noise + return data_out + +class PointcloudOutliers(object): + ''' Point cloud outlier transformation class. + + It adds outliers to point cloud data. + + Args: + ratio (int): outlier percentage to the entire point cloud + ''' + + def __init__(self, ratio): + self.ratio = ratio + + def __call__(self, data): + ''' Calls the transformation. + + Args: + data (dictionary): data dictionary + ''' + data_out = data.copy() + points = data[None] + n_points = points.shape[0] + n_outlier_points = int(n_points*self.ratio) + ind = np.random.randint(0, n_points, n_outlier_points) + + outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3)) + outliers = outliers.astype(np.float32) + points[ind] = outliers + data_out[None] = points + return data_out + +class SubsamplePointcloud(object): + ''' Point cloud subsampling transformation class. + + It subsamples the point cloud data. + + Args: + N (int): number of points to be subsampled + ''' + def __init__(self, N): + self.N = N + + def __call__(self, data): + ''' Calls the transformation. 
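These transforms operate on the `{None: points, 'normals': normals}` dicts produced by `PointCloudField`; a quick stand-alone run of the stack that `get_inputs_field` composes for the outlier setting (noise level and outlier ratio are illustrative, not the demo config values):

```python
import numpy as np
from torchvision import transforms
from src.data.transforms import SubsamplePointcloud, PointcloudNoise, PointcloudOutliers

data = {
    None: np.random.rand(100000, 3).astype(np.float32),
    'normals': np.random.randn(100000, 3).astype(np.float32),
}
tf = transforms.Compose([
    SubsamplePointcloud(3000),        # cfg['data']['pointcloud_n']
    PointcloudNoise(0.005),           # cfg['data']['pointcloud_noise']
    PointcloudOutliers(0.5),          # cfg['data']['pointcloud_outlier_ratio']
])
out = tf(data)
print(out[None].shape, out['normals'].shape)   # (3000, 3) for both
```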
+ + Args: + data (dict): data dictionary + ''' + data_out = data.copy() + points = data[None] + + indices = np.random.randint(points.shape[0], size=self.N) + data_out[None] = points[indices, :] + if 'normals' in data.keys(): + normals = data['normals'] + data_out['normals'] = normals[indices, :] + + return data_out \ No newline at end of file diff --git a/src/data_loader.py b/src/data_loader.py new file mode 100644 index 0000000..75ea59c --- /dev/null +++ b/src/data_loader.py @@ -0,0 +1,228 @@ +import os +import cv2 +import torch +import numpy as np +from glob import glob +from torch.utils import data +from src.utils import load_rgb, load_mask, get_camera_params +from pytorch3d.renderer import PerspectiveCameras +from skimage import img_as_float32 + +################################################## +# Below are for the differentiable renderer +# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py + +def load_rgb(path): + img = imageio.imread(path) + img = img_as_float32(img) + + # pixel values between [-1,1] + # img -= 0.5 + # img *= 2. + + # img = img.transpose(2, 0, 1) + return img + +def load_mask(path): + alpha = imageio.imread(path, as_gray=True) + alpha = img_as_float32(alpha) + object_mask = alpha > 127.5 + + return object_mask + + +def get_camera_params(uv, pose, intrinsics): + if pose.shape[1] == 7: #In case of quaternion vector representation + cam_loc = pose[:, 4:] + R = quat_to_rot(pose[:,:4]) + p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float() + p[:, :3, :3] = R + p[:, :3, 3] = cam_loc + else: # In case of pose matrix representation + cam_loc = pose[:, :3, 3] + p = pose + + batch_size, num_samples, _ = uv.shape + + depth = torch.ones((batch_size, num_samples)) + x_cam = uv[:, :, 0].view(batch_size, -1) + y_cam = uv[:, :, 1].view(batch_size, -1) + z_cam = depth.view(batch_size, -1) + + pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics) + + # permute for batch matrix product + pixel_points_cam = pixel_points_cam.permute(0, 2, 1) + + world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3] + ray_dirs = world_coords - cam_loc[:, None, :] + ray_dirs = F.normalize(ray_dirs, dim=2) + + return ray_dirs, cam_loc + +def quat_to_rot(q): + batch_size, _ = q.shape + q = F.normalize(q, dim=1) + R = torch.ones((batch_size, 3,3)).cuda() + qr=q[:,0] + qi = q[:, 1] + qj = q[:, 2] + qk = q[:, 3] + R[:, 0, 0]=1-2 * (qj**2 + qk**2) + R[:, 0, 1] = 2 * (qj *qi -qk*qr) + R[:, 0, 2] = 2 * (qi * qk + qr * qj) + R[:, 1, 0] = 2 * (qj * qi + qk * qr) + R[:, 1, 1] = 1-2 * (qi**2 + qk**2) + R[:, 1, 2] = 2*(qj*qk - qi*qr) + R[:, 2, 0] = 2 * (qk * qi-qj * qr) + R[:, 2, 1] = 2 * (qj*qk + qi*qr) + R[:, 2, 2] = 1-2 * (qi**2 + qj**2) + return R + +def lift(x, y, z, intrinsics): + # parse intrinsics + # intrinsics = intrinsics.cuda() + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z + y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z + + # homogeneous + return torch.stack((x_lift, y_lift, z, torch.ones_like(z)), dim=-1) + + +class PixelNeRFDTUDataset(data.Dataset): + """ + Processed DTU from pixelNeRF + """ + def __init__(self, + data_dir='data/DTU', + scan_id=65, + img_size=None, + device=None, + fixed_scale=0, + ): + data_dir = os.path.join(data_dir, "scan{}".format(scan_id)) + rgb_paths = [ + x for x in 
glob(os.path.join(data_dir, "image", "*")) + if (x.endswith(".jpg") or x.endswith(".png")) + ] + rgb_paths = sorted(rgb_paths) + mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png"))) + if len(mask_paths) == 0: + mask_paths = [None] * len(rgb_paths) + sel_indices = np.arange(len(rgb_paths)) + + cam_path = os.path.join(data_dir, "cameras.npz") + all_cam = np.load(cam_path) + all_imgs = [] + all_poses = [] + all_masks = [] + all_rays = [] + all_light_pose = [] + all_K = [] + all_R = [] + all_T = [] + + for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)): + + i = sel_indices[idx] + rgb = load_rgb(rgb_path) + mask = load_mask(mask_path) + rgb[~mask] = 0. + rgb = torch.from_numpy(rgb).float().to(device) + mask = torch.from_numpy(mask).float().to(device) + x_scale = y_scale = 1.0 + xy_delta = 0.0 + + P = all_cam["world_mat_" + str(i)] + P = P[:3] + + # scale the original shape to really [-0.9, 0.9] + if fixed_scale!=0.: + scale_mat_new = np.eye(4, 4) + scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9] + P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new + else: + P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] + + P = P[:3, :4] + K, R, t = cv2.decomposeProjectionMatrix(P)[:3] + K = K / K[2, 2] + + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + + ########!!!!! + RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0) + tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0) + focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0) + pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0) + im_size = (rgb.shape[1], rgb.shape[0]) + + # check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC + s = min(im_size) + focal[:, 0] = focal[:, 0] * 2 / (s-1) + focal[:, 1] = focal[:, 1] * 2 /(s-1) + pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1) + pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1) + + camera = PerspectiveCameras(focal_length=-focal, principal_point=pc, + device=device, R=RR, T=tt) + + # calculate camera rays + uv = uv_creation(im_size)[None].float() + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() + pose[:3,3] = (t[:3] / t[3])[:,0] + pose = torch.from_numpy(pose)[None].float() + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + intrinsics[0, 1] = 0. #! 
remove skew for now + intrinsics = torch.from_numpy(intrinsics)[None].float() + + + rays, _ = get_camera_params(uv, pose, intrinsics) + rays = -rays.to(device) + + + + all_poses.append(camera) + all_imgs.append(rgb) + all_masks.append(mask) + all_rays.append(rays) + all_light_pose.append(pose) + # only for neural renderer + all_K.append(torch.tensor(K).to(device)) + all_R.append(torch.tensor(R).to(device)) + all_T.append(torch.tensor(t[:3]/t[3]).to(device)) + + all_imgs = torch.stack(all_imgs) + all_masks = torch.stack(all_masks) + all_rays = torch.stack(all_rays) + all_light_pose = torch.stack(all_light_pose).squeeze() + # only for neural renderer + all_K = torch.stack(all_K).float() + all_R = torch.stack(all_R).float() + all_T = torch.stack(all_T).permute(0, 2, 1).float() + + uv = uv_creation((all_imgs.size(2), all_imgs.size(1))) + self.data = {'rgbs': all_imgs, + 'masks': all_masks, + 'poses': all_poses, + 'rays': all_rays, + 'uv': uv, + 'light_pose': all_light_pose, # for rendering lights + 'K': all_K, + 'R': all_R, + 'T': all_T, + } + + def __len__(self): + return 1 + + def __getitem__(self, idx): + return self.data \ No newline at end of file diff --git a/src/dpsr.py b/src/dpsr.py new file mode 100644 index 0000000..cbb5c9b --- /dev/null +++ b/src/dpsr.py @@ -0,0 +1,75 @@ + +import torch +import torch.nn as nn +from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize +import numpy as np +import torch.fft + +class DPSR(nn.Module): + def __init__(self, res, sig=10, scale=True, shift=True): + """ + :param res: tuple of output field resolution. eg., (128,128) + :param sig: degree of gaussian smoothing + """ + super(DPSR, self).__init__() + self.res = res + self.sig = sig + self.dim = len(res) + self.denom = np.prod(res) + G = spec_gaussian_filter(res=res, sig=sig).float() + G = G + # self.G.requires_grad = False # True, if we also make sig a learnable parameter + self.omega = fftfreqs(res, dtype=torch.float32) + self.scale = scale + self.shift = shift + self.register_buffer("G", G) + + def forward(self, V, N): + """ + :param V: (batch, nv, 2 or 3) tensor for point cloud coordinates + :param N: (batch, nv, 2 or 3) tensor for point normals + :return phi: (batch, res, res, ...) tensor of output indicator function field + """ + assert(V.shape == N.shape) # [b, nv, ndims] + ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2] + + #!!! OLD + # ras_s = torch.rfft(ras_p, signal_ndim=self.dim) # [b, n_dim, dim0, dim1, dim2/2+1, 2] + # ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+2))+[1, self.dim+2])) + # N_ = (ras_s * self.G) # [b, n_dim, dim0, dim1, dim2/2+1, 2] + + ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4)) + ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1])) + N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1] + + omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1] + omega *= 2 * np.pi # normalize frequencies + omega = omega.to(V.device) + + # DivN = torch.sum(-img(N_) * omega, dim=-2) #!!! 
OLD [b, dim0, dim1, dim2/2+1, 2] + DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2) + + Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1] + Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2] + Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b] + Phi[tuple([0] * self.dim)] = 0 + Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2] + + # phi = torch.irfft(Phi, signal_ndim=self.dim, signal_sizes=self.res)#!!! OLD [b, dim0, dim1, dim2] + phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3)) + + if self.shift or self.scale: + # ensure values at points are zero + fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv] + if self.shift: # offset points to have mean of 0 + offset = torch.mean(fv, dim=-1) # [b,] + phi -= offset.view(*tuple([-1] + [1] * self.dim)) + + phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]])) + fv0 = phi[tuple([0] * self.dim)] # [b,] + phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))])) + + if self.scale: + # phi = phi / fv0.view(*tuple([-1] + [1] * self.dim)) * 0.5 + phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5 + return phi \ No newline at end of file diff --git a/src/eval.py b/src/eval.py new file mode 100644 index 0000000..5ce65fe --- /dev/null +++ b/src/eval.py @@ -0,0 +1,168 @@ +import logging +import numpy as np +import trimesh +from pykdtree.kdtree import KDTree + +EMPTY_PCL_DICT = { + 'completeness': np.sqrt(3), + 'accuracy': np.sqrt(3), + 'completeness2': 3, + 'accuracy2': 3, + 'chamfer': 6, +} + +EMPTY_PCL_DICT_NORMALS = { + 'normals completeness': -1., + 'normals accuracy': -1., + 'normals': -1., +} + +logger = logging.getLogger(__name__) + + +class MeshEvaluator(object): + ''' Mesh evaluation class. + It handles the mesh evaluation process. + Args: + n_points (int): number of points to be used for evaluation + ''' + + def __init__(self, n_points=100000): + self.n_points = n_points + + def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)): + ''' Evaluates a mesh. + Args: + mesh (trimesh): mesh which should be evaluated + pointcloud_tgt (numpy array): target point cloud + normals_tgt (numpy array): target normals + thresholds (numpy arry): for F-Score + ''' + if len(mesh.vertices) != 0 and len(mesh.faces) != 0: + pointcloud, idx = mesh.sample(self.n_points, return_index=True) + + pointcloud = pointcloud.astype(np.float32) + normals = mesh.face_normals[idx] + else: + pointcloud = np.empty((0, 3)) + normals = np.empty((0, 3)) + + out_dict = self.eval_pointcloud( + pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds) + + return out_dict + + def eval_pointcloud(self, pointcloud, pointcloud_tgt, + normals=None, normals_tgt=None, + thresholds=np.linspace(1./1000, 1, 1000)): + ''' Evaluates a point cloud. 
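For reference, the FFT branch above solves the (Gaussian-smoothed) Poisson equation of the original PSR formulation in the spectral domain. Up to the sign and the scale/shift normalisation applied at the end of `forward`, `DivN` and `Lap` are the numerator and denominator of the textbook solve, and `Phi[0, ...] = 0` pins the free constant:

```latex
% Differentiable Poisson solve in the spectral domain (sign/normalisation conventions in the
% code are absorbed by its final scale/shift step rather than matching this textbook form).
\nabla^2 \chi = \nabla \cdot \mathbf{v}
\;\;\Longrightarrow\;\;
\hat\chi(\omega) = \frac{i\,\omega \cdot \hat{\mathbf{v}}(\omega)\,\hat g_\sigma(\omega)}{-\,\lVert\omega\rVert^{2}},
\qquad \hat\chi(0) = 0,
```

where `v` is the rasterised normal field (`point_rasterize`) and `g_sigma` the spectral Gaussian stored in the `G` buffer.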
+ Args: + pointcloud (numpy array): predicted point cloud + pointcloud_tgt (numpy array): target point cloud + normals (numpy array): predicted normals + normals_tgt (numpy array): target normals + thresholds (numpy array): threshold values for the F-score calculation + ''' + # Return maximum losses if pointcloud is empty + if pointcloud.shape[0] == 0: + logger.warn('Empty pointcloud / mesh detected!') + out_dict = EMPTY_PCL_DICT.copy() + if normals is not None and normals_tgt is not None: + out_dict.update(EMPTY_PCL_DICT_NORMALS) + return out_dict + + pointcloud = np.asarray(pointcloud) + pointcloud_tgt = np.asarray(pointcloud_tgt) + + + # Completeness: how far are the points of the target point cloud + # from thre predicted point cloud + completeness, completeness_normals = distance_p2p( + pointcloud_tgt, normals_tgt, pointcloud, normals + ) + recall = get_threshold_percentage(completeness, thresholds) + completeness2 = completeness**2 + + completeness = completeness.mean() + completeness2 = completeness2.mean() + completeness_normals = completeness_normals.mean() + + # Accuracy: how far are th points of the predicted pointcloud + # from the target pointcloud + accuracy, accuracy_normals = distance_p2p( + pointcloud, normals, pointcloud_tgt, normals_tgt + ) + precision = get_threshold_percentage(accuracy, thresholds) + accuracy2 = accuracy**2 + + accuracy = accuracy.mean() + accuracy2 = accuracy2.mean() + accuracy_normals = accuracy_normals.mean() + + # Chamfer distance + chamferL2 = 0.5 * (completeness2 + accuracy2) + normals_correctness = ( + 0.5 * completeness_normals + 0.5 * accuracy_normals + ) + chamferL1 = 0.5 * (completeness + accuracy) + + # F-Score + F = [ + 2 * precision[i] * recall[i] / (precision[i] + recall[i]) + for i in range(len(precision)) + ] + + out_dict = { + 'completeness': completeness, + 'accuracy': accuracy, + 'normals completeness': completeness_normals, + 'normals accuracy': accuracy_normals, + 'normals': normals_correctness, + 'completeness2': completeness2, + 'accuracy2': accuracy2, + 'chamfer-L2': chamferL2, + 'chamfer-L1': chamferL1, + 'f-score': F[9], # threshold = 1.0% + 'f-score-15': F[14], # threshold = 1.5% + 'f-score-20': F[19], # threshold = 2.0% + } + + return out_dict + + +def distance_p2p(points_src, normals_src, points_tgt, normals_tgt): + ''' Computes minimal distances of each point in points_src to points_tgt. + Args: + points_src (numpy array): source points + normals_src (numpy array): source normals + points_tgt (numpy array): target points + normals_tgt (numpy array): target normals + ''' + kdtree = KDTree(points_tgt) + dist, idx = kdtree.query(points_src) + + if normals_src is not None and normals_tgt is not None: + normals_src = \ + normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) + normals_tgt = \ + normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True) + + normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1) + # Handle normals that point into wrong direction gracefully + # (mostly due to mehtod not caring about this in generation) + normals_dot_product = np.abs(normals_dot_product) + else: + normals_dot_product = np.array( + [np.nan] * points_src.shape[0], dtype=np.float32) + return dist, normals_dot_product + +def get_threshold_percentage(dist, thresholds): + ''' Evaluates a point cloud. 
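A usage sketch for the evaluator (mesh and targets here are placeholders); note that with `thresholds = np.linspace(1/1000, 1, 1000)` the reported `f-score` entry, `F[9]`, is indeed the F-score at a 1% threshold:

```python
import numpy as np
import trimesh
from src.eval import MeshEvaluator

mesh = trimesh.creation.icosphere(subdivisions=3)           # placeholder prediction
pcl_tgt, idx = mesh.sample(100000, return_index=True)       # placeholder ground truth
normals_tgt = mesh.face_normals[idx]

evaluator = MeshEvaluator(n_points=100000)
out = evaluator.eval_mesh(mesh, pcl_tgt.astype(np.float32), normals_tgt)
print(out['chamfer-L1'], out['f-score'])                    # 'f-score' = F-score @ 1% threshold
```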
+ Args: + dist (numpy array): calculated distance + thresholds (numpy array): threshold values for the F-score calculation + ''' + in_threshold = [ + (dist <= t).mean() for t in thresholds + ] + return in_threshold \ No newline at end of file diff --git a/src/generation.py b/src/generation.py new file mode 100644 index 0000000..9abbe6d --- /dev/null +++ b/src/generation.py @@ -0,0 +1,63 @@ +import torch +import time +import trimesh +import numpy as np +from src.utils import mc_from_psr + +class Generator3D(object): + ''' Generator class for Occupancy Networks. + + It provides functions to generate the final mesh as well refining options. + + Args: + model (nn.Module): trained Occupancy Network model + points_batch_size (int): batch size for points evaluation + threshold (float): threshold value + device (device): pytorch device + padding (float): how much padding should be used for MISE + sample (bool): whether z should be sampled + input_type (str): type of input + ''' + + def __init__(self, model, points_batch_size=100000, + threshold=0.5, device=None, padding=0.1, + sample=False, input_type = None, dpsr=None, psr_tanh=True): + self.model = model.to(device) + self.points_batch_size = points_batch_size + self.threshold = threshold + self.device = device + self.input_type = input_type + self.padding = padding + self.sample = sample + self.dpsr = dpsr + self.psr_tanh = psr_tanh + + def generate_mesh(self, data, return_stats=True): + ''' Generates the output mesh. + + Args: + data (tensor): data tensor + return_stats (bool): whether stats should be returned + ''' + self.model.eval() + device = self.device + stats_dict = {} + + p = data.get('inputs', torch.empty(1, 0)).to(device) + + t0 = time.time() + points, normals = self.model(p) + t1 = time.time() + psr_grid = self.dpsr(points, normals) + t2 = time.time() + v, f, _ = mc_from_psr(psr_grid, + zero_level=self.threshold) + stats_dict['pcl'] = t1 - t0 + stats_dict['dpsr'] = t2 - t1 + stats_dict['mc'] = time.time() - t2 + stats_dict['total'] = time.time() - t0 + + if return_stats: + return v, f, points, normals, stats_dict + else: + return v, f, points, normals \ No newline at end of file diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..a2e978f --- /dev/null +++ b/src/model.py @@ -0,0 +1,181 @@ +import torch +import numpy as np +import time +from src.utils import point_rasterize, grid_interp, mc_from_psr, \ +calc_inters_points +from src.dpsr import DPSR +import torch.nn as nn +from src.network import encoder_dict, decoder_dict +from src.network.utils import map2local + +class PSR2Mesh(torch.autograd.Function): + @staticmethod + def forward(ctx, psr_grid): + """ + In the forward pass we receive a Tensor containing the input and return + a Tensor containing the output. ctx is a context object that can be used + to stash information for backward computation. You can cache arbitrary + objects for use in the backward pass using the ctx.save_for_backward method. + """ + verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True) + verts = verts.unsqueeze(0) + faces = faces.unsqueeze(0) + normals = normals.unsqueeze(0) + + res = torch.tensor(psr_grid.detach().shape[2]) + ctx.save_for_backward(verts, normals, res) + + return verts, faces, normals + + @staticmethod + def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals): + """ + In the backward pass we receive a Tensor containing the gradient of the loss + with respect to the output, and we need to compute the gradient of the loss + with respect to the input. 
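Putting the pieces together, the generation script builds this generator from a config and a trained `Encode2Points` model; a hedged sketch of the inference path (config path is illustrative, checkpoint loading omitted):

```python
import torch
from src import config
from src.model import Encode2Points
from src.utils import load_config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg = load_config('configs/learning_based/demo_outlier.yaml', 'configs/default.yaml')
model = Encode2Points(cfg).to(device)                       # load trained weights here in practice

generator = config.get_generator(model, cfg, device)
data = {'inputs': torch.rand(1, 3000, 3)}                   # placeholder noisy point cloud in [0, 1)
v, f, points, normals, stats = generator.generate_mesh(data)
print(stats)                                                # per-stage timings: 'pcl', 'dpsr', 'mc', 'total'
```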
+ """ + vert_pts, normals, res = ctx.saved_tensors + res = (res.item(), res.item(), res.item()) + # matrix multiplication between dL/dV and dV/dPSR + # dV/dPSR = - normals + grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0)) + grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res + + return grad_grid + +class PSR2SurfacePoints(torch.autograd.Function): + @staticmethod + def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample): + verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True) + verts = verts * 2. - 1. # within the range of [-1, 1] + + + p_all, n_all, mask_all = [], [], [] + + for i in range(len(poses)): + pose = poses[i] + if mask_sample is not None: + p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i]) + else: + p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size) + + n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze() + p_all.append(p_inters) + n_all.append(n_inters) + mask_all.append(mask) + p_inters_all = torch.cat(p_all, dim=0) + n_inters_all = torch.cat(n_all, dim=0) + mask_visible = torch.stack(mask_all, dim=0) + + + res = torch.tensor(psr_grid.detach().shape[2]) + ctx.save_for_backward(p_inters_all, n_inters_all, res) + + return p_inters_all, mask_visible + + @staticmethod + def backward(ctx, dL_dp, dL_dmask): + pts, pts_n, res = ctx.saved_tensors + res = (res.item(), res.item(), res.item()) + + # grad from the p_inters via MLP renderer + grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None]) + grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res + + return grad_grid_pts, None, None, None, None, None + +class Encode2Points(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + encoder = cfg['model']['encoder'] + decoder = cfg['model']['decoder'] + dim = cfg['data']['dim'] # input dim + c_dim = cfg['model']['c_dim'] + encoder_kwargs = cfg['model']['encoder_kwargs'] + if encoder_kwargs == None: + encoder_kwargs = {} + decoder_kwargs = cfg['model']['decoder_kwargs'] + padding = cfg['data']['padding'] + self.predict_normal = cfg['model']['predict_normal'] + self.predict_offset = cfg['model']['predict_offset'] + + out_dim = 3 + out_dim_offset = 3 + num_offset = cfg['data']['num_offset'] + # each point predict more than one offset to add output points + if num_offset > 1: + out_dim_offset = out_dim * num_offset + self.num_offset = num_offset + + # local mapping + self.map2local = None + if cfg['model']['local_coord']: + if 'unet' in encoder_kwargs.keys(): + unit_size = 1 / encoder_kwargs['plane_resolution'] + else: + unit_size = 1 / encoder_kwargs['grid_resolution'] + + local_mapping = map2local(unit_size) + + self.encoder = encoder_dict[encoder]( + dim=dim, c_dim=c_dim, map2local=local_mapping, + **encoder_kwargs + ) + + if self.predict_normal: + # decoder for normal prediction + self.decoder_normal = decoder_dict[decoder]( + dim=dim, c_dim=c_dim, out_dim=out_dim, + **decoder_kwargs) + if self.predict_offset: + # decoder for offset prediction + self.decoder_offset = decoder_dict[decoder]( + dim=dim, c_dim=c_dim, out_dim=out_dim_offset, + map2local=local_mapping, + **decoder_kwargs) + + self.s_off = cfg['model']['s_offset'] + + + def forward(self, p): + ''' Performs a forward pass through the network. 
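The custom backward above is the heart of the differentiable pipeline: marching cubes itself is never differentiated. Instead, following the code comment `dV/dPSR = - normals`, each extracted vertex is assumed to move along its negative normal as the indicator value at that location changes, and the resulting per-vertex scalar is splatted back onto the grid with `point_rasterize`:

```latex
% Gradient through marching cubes as used by PSR2Mesh.backward: a vertex v on the zero level
% set responds to the indicator field chi via dv/dchi ~ -n_v, so the incoming vertex gradients
% are projected onto the (negative) normals and rasterised back to the grid resolution.
\frac{\partial \mathcal{L}}{\partial \chi}
\;\approx\;
\mathrm{rasterize}\!\left(\Bigl\{\, v,\; -\,n_v^{\top}\,\frac{\partial \mathcal{L}}{\partial v} \,\Bigr\}_{v}\right)
```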
+ + Args: + p (tensor): input unoriented points + ''' + + time_dict = {} + mask = None + + batch_size = p.size(0) + points = p.clone() + + # encode the input point cloud to a feature volume + t0 = time.perf_counter() + c = self.encoder(p) + t1 = time.perf_counter() + if self.predict_offset: + offset = self.decoder_offset(p, c) + # more than one offset is predicted per-point + if self.num_offset > 1: + points = points.repeat(1, 1, self.num_offset).reshape(batch_size, -1, 3) + points = points + self.s_off * offset + else: + points = p + + if self.predict_normal: + normals = self.decoder_normal(points, c) + t2 = time.perf_counter() + + time_dict['encode'] = t1 - t0 + time_dict['predict'] = t2 - t1 + + points = torch.clamp(points, 0.0, 0.99) + if self.cfg['model']['normal_normalize']: + normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8) + + + return points, normals + \ No newline at end of file diff --git a/src/model_rgb.py b/src/model_rgb.py new file mode 100644 index 0000000..79f20a1 --- /dev/null +++ b/src/model_rgb.py @@ -0,0 +1,139 @@ +import torch +from src.network.net_rgb import RenderingNetwork +from src.utils import approx_psr_grad +from pytorch3d.renderer import ( + RasterizationSettings, + PerspectiveCameras, + MeshRenderer, + MeshRasterizer, + SoftSilhouetteShader) +from pytorch3d.structures import Meshes + + +def approx_psr_grad(psr_grid, res, normalize=True): + delta_x = delta_y = delta_z = 1/res + psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze() + + grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x + grad_y = (psr_pad[:, 2:, :] - psr_pad[:, :-2, :]) / 2 / delta_y + grad_z = (psr_pad[:, :, 2:] - psr_pad[:, :, :-2]) / 2 / delta_z + grad_x = grad_x[:, 1:-1, 1:-1] + grad_y = grad_y[1:-1, :, 1:-1] + grad_z = grad_z[1:-1, 1:-1, :] + + psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=3) # [res_x, res_y, res_z, 3] + if normalize: + psr_grad = psr_grad / (psr_grad.norm(dim=3, keepdim=True) + 1e-12) + + return psr_grad + + +class SAP2Image(nn.Module): + def __init__(self, cfg, img_size): + super().__init__() + + self.psr2sur = PSR2SurfacePoints.apply + self.psr2mesh = PSR2Mesh.apply + # initialize DPSR + self.dpsr = DPSR(res=(cfg['model']['grid_res'], + cfg['model']['grid_res'], + cfg['model']['grid_res']), + sig=cfg['model']['psr_sigma']) + self.cfg = cfg + if cfg['train']['l_weight']['rgb'] != 0.: + self.rendering_network = RenderingNetwork(**cfg['model']['renderer']) + + if cfg['train']['l_weight']['mask'] != 0.: + # initialize rasterizer + sigma = 1e-4 + raster_settings_soft = RasterizationSettings( + image_size=img_size, + blur_radius=np.log(1. 
/ 1e-4 - 1.)*sigma, + faces_per_pixel=150, + perspective_correct=False + ) + + # initialize silhouette renderer + self.mesh_rasterizer = MeshRenderer( + rasterizer=MeshRasterizer( + raster_settings=raster_settings_soft + ), + shader=SoftSilhouetteShader() + ) + + self.cfg = cfg + self.img_size = img_size + + def forward(self, inputs, data): + points, normals = inputs[...,:3], inputs[...,3:] + points = torch.sigmoid(points) + normals = normals / normals.norm(dim=-1, keepdim=True) + + # DPSR to get grid + psr_grid = self.dpsr(points, normals).unsqueeze(1) + psr_grid = torch.tanh(psr_grid) + + return self.render_img(psr_grid, data) + + def render_img(self, psr_grid, data): + + n_views = len(data['masks']) + n_views_per_iter = self.cfg['data']['n_views_per_iter'] + + rgb_render_mode = self.cfg['model']['renderer']['mode'] + uv = data['uv'] + + idx = np.random.randint(0, n_views, n_views_per_iter) + pose = [data['poses'][i] for i in idx] + rgb = data['rgbs'][idx] + mask_gt = data['masks'][idx] + ray = None + pred_rgb = None + pred_mask = None + + if self.cfg['train']['l_weight']['rgb'] != 0.: + psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res']) + p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None) + n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2) + fea_interp = None + if 'rays' in data.keys(): + ray = data['rays'].squeeze()[idx][visible_mask] + pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp) + + # silhouette loss + if self.cfg['train']['l_weight']['mask'] != 0.: + # build mesh + v, f, _ = self.psr2mesh(psr_grid) + v = v * 2. - 1 # within the range of [-1, 1] + # ! Fast but more GPU usage + mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()]) + if True: + #! PyTorch3D silhouette loss + # build pose + R = torch.cat([p.R for p in pose], dim=0) + T = torch.cat([p.T for p in pose], dim=0) + focal = torch.cat([p.focal_length for p in pose], dim=0) + pp = torch.cat([p.principal_point for p in pose], dim=0) + pose_cur = PerspectiveCameras( + focal_length=focal, + principal_point=pp, + R=R, T=T, + device=R.device) + pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3] + else: + pred_mask = [] + # ! Slow but less GPU usage + for i in range(n_views_per_iter): + #! PyTorch3D silhouette loss + pred_mask.append(self.mesh_rasterizer(mesh, cameras=pose[i])[..., 3]) + pred_mask = torch.cat(pred_mask, dim=0) + + output = { + 'rgb': pred_rgb, + 'rgb_gt': rgb, + 'mask': pred_mask, + 'mask_gt': mask_gt, + 'vis_mask': visible_mask, + } + + return output \ No newline at end of file diff --git a/src/network/__init__.py b/src/network/__init__.py new file mode 100644 index 0000000..f6a0b61 --- /dev/null +++ b/src/network/__init__.py @@ -0,0 +1,8 @@ +from src.network import encoder, decoder + +encoder_dict = { + 'local_pool_pointnet': encoder.LocalPoolPointnet, +} +decoder_dict = { + 'simple_local': decoder.LocalDecoder, +} \ No newline at end of file diff --git a/src/network/decoder.py b/src/network/decoder.py new file mode 100644 index 0000000..5fc7bd1 --- /dev/null +++ b/src/network/decoder.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ipdb import set_trace as st +from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \ + normalize_coordinate + + +class LocalDecoder(nn.Module): + ''' Decoder. 
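In the image-supervised setting above, the only free parameters are the raw point coordinates and normals themselves; `SAP2Image.forward` squashes them into the domain DPSR expects before rendering. A hedged, shapes-only sketch of that parameterisation:

```python
import torch

# Hypothetical optimisable oriented point cloud for SAP2Image: (B, N, 6) raw parameters.
inputs = torch.nn.Parameter(torch.randn(1, 20000, 6) * 0.1)

points  = torch.sigmoid(inputs[..., :3])                     # keep coordinates in (0, 1) for DPSR
normals = inputs[..., 3:] / inputs[..., 3:].norm(dim=-1, keepdim=True)
# psr_grid = torch.tanh(dpsr(points, normals).unsqueeze(1))  # then rendered via PSR2SurfacePoints
#                                                            # and/or the soft silhouette rasteriser above
print(points.shape, normals.shape)
```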
+ Instead of conditioning on global features, on plane/volume local features. + Args: + dim (int): input dimension + c_dim (int): dimension of latent conditioned code c + hidden_size (int): hidden size of Decoder network + n_blocks (int): number of blocks ResNetBlockFC layers + leaky (bool): whether to use leaky ReLUs + sample_mode (str): sampling feature strategy, bilinear|nearest + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + ''' + + def __init__(self, dim=3, c_dim=128, out_dim=3, + hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None): + super().__init__() + self.c_dim = c_dim + self.n_blocks = n_blocks + + if c_dim != 0: + self.fc_c = nn.ModuleList([ + nn.Linear(c_dim, hidden_size) for i in range(n_blocks) + ]) + + + self.fc_p = nn.Linear(dim, hidden_size) + + self.blocks = nn.ModuleList([ + ResnetBlockFC(hidden_size) for i in range(n_blocks) + ]) + + self.fc_out = nn.Linear(hidden_size, out_dim) + + if not leaky: + self.actvn = F.relu + else: + self.actvn = lambda x: F.leaky_relu(x, 0.2) + + self.sample_mode = sample_mode + self.padding = padding + self.map2local = map2local + self.out_dim = out_dim + + + def sample_plane_feature(self, p, c, plane='xz'): + xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1) + xy = xy[:, :, None].float() + vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) + c = F.grid_sample(c, vgrid, padding_mode='border', + align_corners=True, + mode=self.sample_mode).squeeze(-1) + return c + + def sample_grid_feature(self, p, c): + p_nor = normalize_3d_coordinate(p.clone()) + p_nor = p_nor[:, :, None, None].float() + vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1) + # acutally trilinear interpolation if mode = 'bilinear' + c = F.grid_sample(c, vgrid, padding_mode='border', + align_corners=True, + mode=self.sample_mode).squeeze(-1).squeeze(-1) + return c + + def forward(self, p, c_plane, **kwargs): + batch_size = p.shape[0] + plane_type = list(c_plane.keys()) + c = 0 + if 'grid' in plane_type: + c += self.sample_grid_feature(p, c_plane['grid']) + if 'xz' in plane_type: + c += self.sample_plane_feature(p, c_plane['xz'], plane='xz') + if 'xy' in plane_type: + c += self.sample_plane_feature(p, c_plane['xy'], plane='xy') + if 'yz' in plane_type: + c += self.sample_plane_feature(p, c_plane['yz'], plane='yz') + c = c.transpose(1, 2) + + p = p.float() + + if self.map2local: + p = self.map2local(p) + + net = self.fc_p(p) + + for i in range(self.n_blocks): + if self.c_dim != 0: + net = net + self.fc_c[i](c) + + net = self.blocks[i](net) + + out = self.fc_out(self.actvn(net)) + + + if self.out_dim > 3: + out = out.reshape(batch_size, -1, 3) + + return out \ No newline at end of file diff --git a/src/network/encoder.py b/src/network/encoder.py new file mode 100644 index 0000000..4385a9f --- /dev/null +++ b/src/network/encoder.py @@ -0,0 +1,181 @@ +import torch +import torch.nn as nn +import numpy as np +from src.network.unet3d import UNet3D +from src.network.unet import UNet +from ipdb import set_trace as st +from torch_scatter import scatter_mean, scatter_max +from src.network.utils import get_embedder, normalize_3d_coordinate,\ +coordinate2index, ResnetBlockFC, normalize_coordinate + +class LocalPoolPointnet(nn.Module): + ''' PointNet-based encoder network with ResNet blocks for each point. + Number of input points are fixed. 
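`sample_grid_feature` / `sample_plane_feature` are thin wrappers around `F.grid_sample`: the encoder below scatters per-point features into planes or a volume, and the decoder interpolates them back at the query points. A shape-level illustration of the volume case (random features, illustrative sizes):

```python
import torch
import torch.nn.functional as F

feat = torch.randn(1, 32, 32, 32, 32)             # (B, c_dim, reso, reso, reso) feature volume
p = torch.rand(1, 2048, 3)                         # query points, normalised to (0, 1)

vgrid = (2.0 * p - 1.0)[:, :, None, None]          # (B, N, 1, 1, 3) in (-1, 1), as in sample_grid_feature
c = F.grid_sample(feat, vgrid, padding_mode='border', align_corners=True,
                  mode='bilinear')                 # trilinear interpolation for 5-D input
c = c.squeeze(-1).squeeze(-1)                      # (B, c_dim, N) per-point features
print(c.shape)
```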
+ + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + unet3d (bool): weather to use 3D U-Net + unet3d_kwargs (str): 3D U-Net parameters + plane_resolution (int): defined resolution for plane feature + grid_resolution (int): defined resolution for grid feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + map2local (function): map global coordintes to local ones + pos_encoding (int): frequency for the positional encoding + ''' + + def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', + unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, + plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5, + map2local=None, pos_encoding=0): + super().__init__() + + self.c_dim = c_dim + + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + + self.unet = None + if unet: + self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) + + self.unet3d = None + if unet3d: + self.unet3d = UNet3D(**unet3d_kwargs) + + self.reso_plane = plane_resolution + self.reso_grid = grid_resolution + self.plane_type = plane_type + self.padding = padding + + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + self.pe = None + if pos_encoding > 0: + embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim) + self.pe = embed_fn + self.fc_pos = nn.Linear(input_ch, 2*hidden_dim) + + self.map2local = map2local + + + if scatter_type == 'max': + self.scatter = scatter_max + elif scatter_type == 'mean': + self.scatter = scatter_mean + else: + raise ValueError('incorrect scatter type') + + def generate_plane_features(self, p, c, plane='xz'): + # acquire indices of features in plane + xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1) + index = coordinate2index(xy, self.reso_plane) + + # scatter plane features from points + fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2) + c = c.permute(0, 2, 1) # B x 512 x T + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso) + + # process the plane features with UNet + if self.unet is not None: + fea_plane = self.unet(fea_plane) + + return fea_plane + + def generate_grid_features(self, p, c): + p_nor = normalize_3d_coordinate(p.clone()) + index = coordinate2index(p_nor, self.reso_grid, coord_type='3d') + # scatter grid features from points + fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3) + c = c.permute(0, 2, 1) + fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3 + fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparce matrix (B x 512 x reso x reso) + + if self.unet3d is not None: + fea_grid = self.unet3d(fea_grid) + + return fea_grid + + def pool_local(self, xy, index, c): + bs, fea_dim = c.size(0), c.size(2) + keys = xy.keys() + 
+ c_out = 0 + for key in keys: + # scatter plane features from points + if key == 'grid': + fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3) + else: + fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out += fea + return c_out.permute(0, 2, 1) + + + def forward(self, p, normalize=True): + batch_size, T, D = p.size() + + # acquire the index for each point + coord = {} + index = {} + if 'xz' in self.plane_type: + coord['xz'] = normalize_coordinate(p.clone(), plane='xz') + index['xz'] = coordinate2index(coord['xz'], self.reso_plane) + if 'xy' in self.plane_type: + coord['xy'] = normalize_coordinate(p.clone(), plane='xy') + index['xy'] = coordinate2index(coord['xy'], self.reso_plane) + if 'yz' in self.plane_type: + coord['yz'] = normalize_coordinate(p.clone(), plane='yz') + index['yz'] = coordinate2index(coord['yz'], self.reso_plane) + if 'grid' in self.plane_type: + if normalize: + coord['grid'] = normalize_3d_coordinate(p.clone()) + else: + coord['grid'] = p.clone()[...,:3] + index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d') + + + if self.pe: + p = self.pe(p) + + if self.map2local: + pp = self.map2local(p) + net = self.fc_pos(pp) + else: + net = self.fc_pos(p) + # net = self.fc_pos(p) + + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(coord, index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + + fea = {} + if 'grid' in self.plane_type: + fea['grid'] = self.generate_grid_features(p, c) + if 'xz' in self.plane_type: + fea['xz'] = self.generate_plane_features(p, c, plane='xz') + if 'xy' in self.plane_type: + fea['xy'] = self.generate_plane_features(p, c, plane='xy') + if 'yz' in self.plane_type: + fea['yz'] = self.generate_plane_features(p, c, plane='yz') + + return fea \ No newline at end of file diff --git a/src/network/net_rgb.py b/src/network/net_rgb.py new file mode 100644 index 0000000..acdde01 --- /dev/null +++ b/src/network/net_rgb.py @@ -0,0 +1,234 @@ +# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py) +import torch +import torch.nn as nn +import numpy as np +from src.network.utils import get_embedder +from pdb import set_trace as st + +class RenderingNetwork(nn.Module): + def __init__( + self, + fea_size=0, + mode='naive', + d_out=3, + dims=[512, 512, 512, 512], + weight_norm=True, + pe_freq_view=0 # for positional encoding + ): + super().__init__() + + self.mode = mode + if mode == 'naive': + d_in = 3 + elif mode == 'no_feature': + d_in = 3 + 3 + 3 + fea_size = 0 + elif mode == 'full': + d_in = 3 + 3 + 3 + else: + d_in = 3 + 3 + dims = [d_in + fea_size] + dims + [d_out] + + self.embedview_fn = None + if pe_freq_view > 0: + embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3) + self.embedview_fn = embedview_fn + dims[0] += (input_ch - 3) + + self.num_layers = len(dims) + + for l in range(0, self.num_layers - 1): + out_dim = dims[l + 1] + lin = nn.Linear(dims[l], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + + def forward(self, points, normals=None, view_dirs=None, feature_vectors=None): + if self.embedview_fn is not None: + view_dirs = self.embedview_fn(view_dirs) + # points = 
self.embedview_fn(points) + + if (self.mode == 'full') & (feature_vectors is not None): + rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) + elif (self.mode == 'no_feature') | ((self.mode == 'full') & (feature_vectors is None)): + rendering_input = torch.cat([points, view_dirs, normals], dim=-1) + elif self.mode == 'no_view_dir': + rendering_input = torch.cat([points, normals], dim=-1) + elif self.mode == 'no_normal': + rendering_input = torch.cat([points, view_dirs], dim=-1) + else: + rendering_input = points + + x = rendering_input + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) + + x = self.tanh(x) + return x + + +class NeRFRenderingNetwork(nn.Module): + def __init__( + self, + feature_vector_size=0, + mode='naive', + d_in=3, + d_out=3, + dims=[512, 512, 512, 256], + weight_norm=True, + multires=0, # positional encoding of points + multires_view=0 # positional encoding of view + ): + super().__init__() + + self.mode = mode + dims = [d_in + feature_vector_size] + dims + + + self.embed_fn = None + if multires > 0: + embed_fn, input_ch = get_embedder(multires, d_in=d_in) + self.embed_fn = embed_fn + dims[0] += (input_ch - 3) + + self.num_layers = len(dims) + + self.pts_net = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.num_layers - 1)]) + + self.embedview_fn = None + if multires_view > 0: + embedview_fn, view_ch = get_embedder(multires_view, d_in=3) + self.embedview_fn = embedview_fn + # dims[0] += (input_ch - 3) + + if mode == 'full': + self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)]) + self.rgb_net = nn.Linear(128, 3) + else: + self.rgb_net = nn.Linear(dims[-1], 3) + + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + + def forward(self, points, normals=None, view_dirs=None, feature_vectors=None): + if self.embed_fn is not None: + points = self.embed_fn(points) + if self.embedview_fn is not None: + view_dirs = self.embedview_fn(view_dirs) + + x = points + for net in self.pts_net: + x = net(x) + x = self.relu(x) + + if self.mode=='full': + x = torch.cat([x, view_dirs], -1) + for net in self.view_net: + x = net(x) + x = self.relu(x) + + x = self.rgb_net(x) + x = self.tanh(x) + return x + +class ImplicitNetwork(nn.Module): + def __init__( + self, + d_in, + d_out, + dims, + geometric_init=True, + feature_vector_size=0, + bias=1.0, + skip_in=(), + weight_norm=True, + multires=0 + ): + super().__init__() + + dims = [d_in] + dims + [d_out + feature_vector_size] + + self.embed_fn = None + if multires > 0: + embed_fn, input_ch = get_embedder(multires) + self.embed_fn = embed_fn + dims[0] = input_ch + + self.num_layers = len(dims) + self.skip_in = skip_in + + for l in range(0, self.num_layers - 1): + if l + 1 in self.skip_in: + out_dim = dims[l + 1] - dims[0] + else: + out_dim = dims[l + 1] + + lin = nn.Linear(dims[l], out_dim) + + if geometric_init: + if l == self.num_layers - 2: + torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) + torch.nn.init.constant_(lin.bias, -bias) + elif multires > 0 and l == 0: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.constant_(lin.weight[:, 3:], 0.0) + torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) + elif multires > 0 and l in self.skip_in: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) + 
else: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.softplus = nn.Softplus(beta=100) + + def forward(self, input, compute_grad=False): + if self.embed_fn is not None: + input = self.embed_fn(input) + + x = input + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + if l in self.skip_in: + x = torch.cat([x, input], 1) / np.sqrt(2) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.softplus(x) + + return x + + def gradient(self, x): + x.requires_grad_(True) + y = self.forward(x)[:,:1] + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return gradients.unsqueeze(1) \ No newline at end of file diff --git a/src/network/unet.py b/src/network/unet.py new file mode 100644 index 0000000..a58cc31 --- /dev/null +++ b/src/network/unet.py @@ -0,0 +1,256 @@ +''' +Codes are from: +https://github.com/jaxony/unet-pytorch/blob/master/model.py +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +import numpy as np + +def conv3x3(in_channels, out_channels, stride=1, + padding=1, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=bias, + groups=groups) + +def upconv2x2(in_channels, out_channels, mode='transpose'): + if mode == 'transpose': + return nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=2, + stride=2) + else: + # out_channels is always going to be the same + # as in_channels + return nn.Sequential( + nn.Upsample(mode='bilinear', scale_factor=2), + conv1x1(in_channels, out_channels)) + +def conv1x1(in_channels, out_channels, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + groups=groups, + stride=1) + + +class DownConv(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 MaxPool. + A ReLU activation follows each convolution. + """ + def __init__(self, in_channels, out_channels, pooling=True): + super(DownConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.pooling = pooling + + self.conv1 = conv3x3(self.in_channels, self.out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + if self.pooling: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + before_pool = x + if self.pooling: + x = self.pool(x) + return x, before_pool + + +class UpConv(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 UpConvolution. + A ReLU activation follows each convolution. 
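+ The upsampled decoder features are merged with the corresponding encoder
+ features either by concatenation ('concat') or by element-wise summation
+ ('add'), depending on merge_mode.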
+ """ + def __init__(self, in_channels, out_channels, + merge_mode='concat', up_mode='transpose'): + super(UpConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.merge_mode = merge_mode + self.up_mode = up_mode + + self.upconv = upconv2x2(self.in_channels, self.out_channels, + mode=self.up_mode) + + if self.merge_mode == 'concat': + self.conv1 = conv3x3( + 2*self.out_channels, self.out_channels) + else: + # num of input channels to conv2 is same + self.conv1 = conv3x3(self.out_channels, self.out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + + def forward(self, from_down, from_up): + """ Forward pass + Arguments: + from_down: tensor from the encoder pathway + from_up: upconv'd tensor from the decoder pathway + """ + from_up = self.upconv(from_up) + if self.merge_mode == 'concat': + x = torch.cat((from_up, from_down), 1) + else: + x = from_up + from_down + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + return x + + +class UNet(nn.Module): + """ `UNet` class is based on https://arxiv.org/abs/1505.04597 + + The U-Net is a convolutional encoder-decoder neural network. + Contextual spatial information (from the decoding, + expansive pathway) about an input tensor is merged with + information representing the localization of details + (from the encoding, compressive pathway). + + Modifications to the original paper: + (1) padding is used in 3x3 convolutions to prevent loss + of border pixels + (2) merging outputs does not require cropping due to (1) + (3) residual connections can be used by specifying + UNet(merge_mode='add') + (4) if non-parametric upsampling is used in the decoder + pathway (specified by upmode='upsample'), then an + additional 1x1 2d convolution occurs after upsampling + to reduce channel dimensionality by a factor of 2. + This channel halving happens with the convolution in + the tranpose convolution (specified by upmode='transpose') + """ + + def __init__(self, num_classes, in_channels=3, depth=5, + start_filts=64, up_mode='transpose', + merge_mode='concat', **kwargs): + """ + Arguments: + in_channels: int, number of channels in the input tensor. + Default is 3 for RGB images. + depth: int, number of MaxPools in the U-Net. + start_filts: int, number of convolutional filters for the + first conv. + up_mode: string, type of upconvolution. Choices: 'transpose' + for transpose convolution or 'upsample' for nearest neighbour + upsampling. + """ + super(UNet, self).__init__() + + if up_mode in ('transpose', 'upsample'): + self.up_mode = up_mode + else: + raise ValueError("\"{}\" is not a valid mode for " + "upsampling. Only \"transpose\" and " + "\"upsample\" are allowed.".format(up_mode)) + + if merge_mode in ('concat', 'add'): + self.merge_mode = merge_mode + else: + raise ValueError("\"{}\" is not a valid mode for" + "merging up and down paths. 
" + "Only \"concat\" and " + "\"add\" are allowed.".format(up_mode)) + + # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' + if self.up_mode == 'upsample' and self.merge_mode == 'add': + raise ValueError("up_mode \"upsample\" is incompatible " + "with merge_mode \"add\" at the moment " + "because it doesn't make sense to use " + "nearest neighbour to reduce " + "depth channels (by half).") + + self.num_classes = num_classes + self.in_channels = in_channels + self.start_filts = start_filts + self.depth = depth + + self.down_convs = [] + self.up_convs = [] + + # create the encoder pathway and add to a list + for i in range(depth): + ins = self.in_channels if i == 0 else outs + outs = self.start_filts*(2**i) + pooling = True if i < depth-1 else False + + down_conv = DownConv(ins, outs, pooling=pooling) + self.down_convs.append(down_conv) + + # create the decoder pathway and add to a list + # - careful! decoding only requires depth-1 blocks + for i in range(depth-1): + ins = outs + outs = ins // 2 + up_conv = UpConv(ins, outs, up_mode=up_mode, + merge_mode=merge_mode) + self.up_convs.append(up_conv) + + # add the list of modules to current module + self.down_convs = nn.ModuleList(self.down_convs) + self.up_convs = nn.ModuleList(self.up_convs) + + self.conv_final = conv1x1(outs, self.num_classes) + + self.reset_params() + + @staticmethod + def weight_init(m): + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight) + init.constant_(m.bias, 0) + + + def reset_params(self): + for i, m in enumerate(self.modules()): + self.weight_init(m) + + + def forward(self, x): + encoder_outs = [] + # encoder pathway, save outputs for merging + for i, module in enumerate(self.down_convs): + x, before_pool = module(x) + encoder_outs.append(before_pool) + for i, module in enumerate(self.up_convs): + before_pool = encoder_outs[-(i+2)] + x = module(before_pool, x) + + # No softmax is used. This means you need to use + # nn.CrossEntropyLoss is your training script, + # as this module includes a softmax already. + x = self.conv_final(x) + return x + +if __name__ == "__main__": + """ + testing + """ + model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32) + print(model) + print(sum(p.numel() for p in model.parameters())) + + reso = 176 + x = np.zeros((1, 1, reso, reso)) + x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan + x = torch.FloatTensor(x) + + out = model(x) + print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso))) + + # loss = torch.sum(out) + # loss.backward() diff --git a/src/network/unet3d.py b/src/network/unet3d.py new file mode 100644 index 0000000..3c2bae2 --- /dev/null +++ b/src/network/unet3d.py @@ -0,0 +1,559 @@ +''' +Code from the 3D UNet implementation: +https://github.com/wolny/pytorch-3dunet/ +''' +import importlib +import torch +import torch.nn as nn +from torch.nn import functional as F +from functools import partial +from src.network.utils import get_embedder + +def number_of_features_per_level(init_channel_number, num_levels): + return [init_channel_number * 2 ** k for k in range(num_levels)] + + +def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): + return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) + + +def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): + """ + Create a list of modules with together constitute a single conv layer with non-linearity + and optional batchnorm/groupnorm. 
+ + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + order (string): order of things, e.g. + 'cr' -> conv + ReLU + 'gcr' -> groupnorm + conv + ReLU + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + 'bcr' -> batchnorm + conv + ReLU + num_groups (int): number of groups for the GroupNorm + padding (int): add zero-padding to the input + + Return: + list of tuple (name, module) + """ + assert 'c' in order, "Conv layer MUST be present" + assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' + + modules = [] + for i, char in enumerate(order): + if char == 'r': + modules.append(('ReLU', nn.ReLU(inplace=True))) + elif char == 'l': + modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) + elif char == 'e': + modules.append(('ELU', nn.ELU(inplace=True))) + elif char == 'c': + # add learnable bias only in the absence of batchnorm/groupnorm + bias = not ('g' in order or 'b' in order) + modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) + elif char == 'g': + is_before_conv = i < order.index('c') + if is_before_conv: + num_channels = in_channels + else: + num_channels = out_channels + + # use only one group if the given number of groups is greater than the number of channels + if num_channels < num_groups: + num_groups = 1 + + assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' + modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) + elif char == 'b': + is_before_conv = i < order.index('c') + if is_before_conv: + modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) + else: + modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) + else: + raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") + + return modules + + +class SingleConv(nn.Sequential): + """ + Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order + of operations can be specified via the `order` parameter + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + kernel_size (int): size of the convolving kernel + order (string): determines the order of layers, e.g. + 'cr' -> conv + ReLU + 'crg' -> conv + ReLU + groupnorm + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + num_groups (int): number of groups for the GroupNorm + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1): + super(SingleConv, self).__init__() + + for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): + self.add_module(name, module) + + +class DoubleConv(nn.Sequential): + """ + A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). + We use (Conv3d+ReLU+GroupNorm3d) by default. + This can be changed however by providing the 'order' argument, e.g. in order + to change to Conv3d+BatchNorm3d+ELU use order='cbe'. + Use padded convolutions to make sure that the output (H_out, W_out) is the same + as (H_in, W_in), so that you don't have to crop in the decoder path. 
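+ In the encoder path the first convolution maps in_channels to
+ max(out_channels // 2, in_channels) and the second to out_channels; in the
+ decoder path both convolutions produce out_channels output channels.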
+ + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + encoder (bool): if True we're in the encoder path, otherwise we're in the decoder + kernel_size (int): size of the convolving kernel + order (string): determines the order of layers, e.g. + 'cr' -> conv + ReLU + 'crg' -> conv + ReLU + groupnorm + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + num_groups (int): number of groups for the GroupNorm + """ + + def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8): + super(DoubleConv, self).__init__() + if encoder: + # we're in the encoder path + conv1_in_channels = in_channels + conv1_out_channels = out_channels // 2 + if conv1_out_channels < in_channels: + conv1_out_channels = in_channels + conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels + else: + # we're in the decoder path, decrease the number of channels in the 1st convolution + conv1_in_channels, conv1_out_channels = in_channels, out_channels + conv2_in_channels, conv2_out_channels = out_channels, out_channels + + # conv1 + self.add_module('SingleConv1', + SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) + # conv2 + self.add_module('SingleConv2', + SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) + + +class ExtResNetBlock(nn.Module): + """ + Basic UNet block consisting of a SingleConv followed by the residual block. + The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number + of output channels is compatible with the residual block that follows. + This block can be used instead of standard DoubleConv in the Encoder module. + Motivated by: https://arxiv.org/pdf/1706.00120.pdf + + Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): + super(ExtResNetBlock, self).__init__() + + # first convolution + self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) + # residual block + self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) + # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual + n_order = order + for c in 'rel': + n_order = n_order.replace(c, '') + self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, + num_groups=num_groups) + + # create non-linearity separately + if 'l' in order: + self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif 'e' in order: + self.non_linearity = nn.ELU(inplace=True) + else: + self.non_linearity = nn.ReLU(inplace=True) + + def forward(self, x): + # apply first convolution and save the output as a residual + out = self.conv1(x) + residual = out + + # residual block + out = self.conv2(out) + out = self.conv3(out) + + out += residual + out = self.non_linearity(out) + + return out + + +class Encoder(nn.Module): + """ + A single module from the encoder path consisting of the optional max + pooling layer (one may specify the MaxPool kernel_size to be different + than the standard (2,2,2), e.g. if the volumetric data is anisotropic + (make sure to use complementary scale_factor in the decoder path) followed by + a DoubleConv module. 
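+ For example, with the default DoubleConv module and (2, 2, 2) max pooling,
+ Encoder(32, 64) maps a (B, 32, D, H, W) volume to (B, 64, D/2, H/2, W/2).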
+ Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + conv_kernel_size (int): size of the convolving kernel + apply_pooling (bool): if True use MaxPool3d before DoubleConv + pool_kernel_size (tuple): the size of the window to take a max over + pool_type (str): pooling layer: 'max' or 'avg' + basic_module(nn.Module): either ResNetBlock or DoubleConv + conv_layer_order (string): determines the order of layers + in `DoubleConv` module. See `DoubleConv` for more info. + num_groups (int): number of groups for the GroupNorm + """ + + def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, + pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg', + num_groups=8): + super(Encoder, self).__init__() + assert pool_type in ['max', 'avg'] + if apply_pooling: + if pool_type == 'max': + self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) + else: + self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) + else: + self.pooling = None + + self.basic_module = basic_module(in_channels, out_channels, + encoder=True, + kernel_size=conv_kernel_size, + order=conv_layer_order, + num_groups=num_groups) + + def forward(self, x): + if self.pooling is not None: + x = self.pooling(x) + x = self.basic_module(x) + return x + + +class Decoder(nn.Module): + """ + A single module for decoder path consisting of the upsampling layer + (either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock). + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + kernel_size (int): size of the convolving kernel + scale_factor (tuple): used as the multiplier for the image H/W/D in + case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation + from the corresponding encoder + basic_module(nn.Module): either ResNetBlock or DoubleConv + conv_layer_order (string): determines the order of layers + in `DoubleConv` module. See `DoubleConv` for more info. 
+ num_groups (int): number of groups for the GroupNorm + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv, + conv_layer_order='crg', num_groups=8, mode='nearest'): + super(Decoder, self).__init__() + if basic_module == DoubleConv: + # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining + self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) + # concat joining + self.joining = partial(self._joining, concat=True) + else: + # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining + self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) + # sum joining + self.joining = partial(self._joining, concat=False) + # adapt the number of in_channels for the ExtResNetBlock + in_channels = out_channels + + self.basic_module = basic_module(in_channels, out_channels, + encoder=False, + kernel_size=kernel_size, + order=conv_layer_order, + num_groups=num_groups) + + def forward(self, encoder_features, x): + x = self.upsampling(encoder_features=encoder_features, x=x) + x = self.joining(encoder_features, x) + x = self.basic_module(x) + return x + + @staticmethod + def _joining(encoder_features, x, concat): + if concat: + return torch.cat((encoder_features, x), dim=1) + else: + return encoder_features + x + + +class Upsampling(nn.Module): + """ + Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution. + + Args: + transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation + concat_joining (bool): if True uses concatenation joining between encoder and decoder features, otherwise + uses summation joining (see Residual U-Net) + in_channels (int): number of input channels for transposed conv + out_channels (int): number of output channels for transpose conv + kernel_size (int or tuple): size of the convolving kernel + scale_factor (int or tuple): stride of the convolution + mode (str): algorithm used for upsampling: + 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest' + """ + + def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3, + scale_factor=(2, 2, 2), mode='nearest'): + super(Upsampling, self).__init__() + + if transposed_conv: + # make sure that the output size reverses the MaxPool3d from the corresponding encoder + # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) + self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, + padding=1) + else: + self.upsample = partial(self._interpolate, mode=mode) + + def forward(self, encoder_features, x): + output_size = encoder_features.size()[2:] + return self.upsample(x, output_size) + + @staticmethod + def _interpolate(x, size, mode): + return F.interpolate(x, size=size, mode=mode) + + +class FinalConv(nn.Sequential): + """ + A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution + which reduces the number of channels to 'out_channels'. + with the number of output channels 'out_channels // 2' and 'out_channels' respectively. + We use (Conv3d+ReLU+GroupNorm3d) by default. 
+ This can be change however by providing the 'order' argument, e.g. in order + to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + kernel_size (int): size of the convolving kernel + order (string): determines the order of layers, e.g. + 'cr' -> conv + ReLU + 'crg' -> conv + ReLU + groupnorm + num_groups (int): number of groups for the GroupNorm + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8): + super(FinalConv, self).__init__() + + # conv1 + self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) + + # in the last layer a 1×1 convolution reduces the number of output channels to out_channels + final_conv = nn.Conv3d(in_channels, out_channels, 1) + self.add_module('final_conv', final_conv) + +class Abstract3DUNet(nn.Module): + """ + Base class for standard and residual UNet. + + Args: + in_channels (int): number of input channels + out_channels (int): number of output segmentation masks; + Note that that the of out_channels might correspond to either + different semantic classes or to different binary segmentation mask. + It's up to the user of the class to interpret the out_channels and + use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) + or BCEWithLogitsLoss (two-class) respectively) + f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number + of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 + final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the + final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used + to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. + basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....) + layer_order (string): determines the order of layers + in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. + See `SingleConv` for more info + f_maps (int, tuple): if int: number of feature maps in the first conv layer of the encoder (default: 64); + if tuple: number of feature maps at each level + num_groups (int): number of groups for the GroupNorm + num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) + is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied + after the final convolution; if False (regression problem) the normalization layer is skipped at the end + testing (bool): if True (testing mode) the `final_activation` (if present, i.e. `is_segmentation=true`) + will be applied as the last operation during the forward pass; if False the model is in training mode + and the `final_activation` (even if present) won't be applied; default: False + """ + + def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs): + super(Abstract3DUNet, self).__init__() + + self.testing = testing + + if isinstance(f_maps, int): + f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) + + # create encoder path consisting of Encoder modules. 
Depth of the encoder is equal to `len(f_maps)` + + self.embed_fn = None + if pe_freq > 0: + embed_fn, input_ch = get_embedder(pe_freq, d_in=in_channels) + self.embed_fn = embed_fn + in_channels = input_ch + + encoders = [] + for i, out_feature_num in enumerate(f_maps): + if i == 0: + encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module, + conv_layer_order=layer_order, num_groups=num_groups) + else: + # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations + # currently pools with a constant kernel: (2, 2, 2) + encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module, + conv_layer_order=layer_order, num_groups=num_groups) + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1` + decoders = [] + reversed_f_maps = list(reversed(f_maps)) + for i in range(len(reversed_f_maps) - 1): + if basic_module == DoubleConv: + in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] + else: + in_feature_num = reversed_f_maps[i] + + out_feature_num = reversed_f_maps[i + 1] + # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv + # currently strides with a constant stride: (2, 2, 2) + decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module, + conv_layer_order=layer_order, num_groups=num_groups) + decoders.append(decoder) + + self.decoders = nn.ModuleList(decoders) + + # in the last layer a 1×1 convolution reduces the number of output + # channels to the number of labels + self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) + + if is_segmentation: + # semantic segmentation problem + if final_sigmoid: + self.final_activation = nn.Sigmoid() + else: + self.final_activation = nn.Softmax(dim=1) + else: + # regression problem + self.final_activation = None + + def forward(self, x): + + if self.embed_fn is not None: + x = self.embed_fn(x.permute(0, 2, 3, 4, 1)) + x = x.permute(0, 4, 1, 2, 3) + + # encoder part + encoders_features = [] + for encoder in self.encoders: + x = encoder(x) + # reverse the encoder outputs to be aligned with the decoder + encoders_features.insert(0, x) + + # remove the last encoder's output from the list + # !!remember: it's the 1st in the list + encoders_features = encoders_features[1:] + + # decoder part + for decoder, encoder_features in zip(self.decoders, encoders_features): + # pass the output from the corresponding encoder and the output + # of the previous decoder + x = decoder(encoder_features, x) + + x = self.final_conv(x) + + # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs + # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric + if self.testing and self.final_activation is not None: + x = self.final_activation(x) + + return x + + +class UNet3D(Abstract3DUNet): + """ + 3DUnet model from + `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" + `. 
+ + Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder + """ + + def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, is_segmentation=True, **kwargs): + super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid, + basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order, + num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation, + **kwargs) + + +class ResidualUNet3D(Abstract3DUNet): + """ + Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. + Uses ExtResNetBlock as a basic building block, summation joining instead + of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts). + Since the model effectively becomes a residual net, in theory it allows for deeper UNet. + """ + + def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=5, is_segmentation=True, **kwargs): + super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order, + num_groups=num_groups, num_levels=num_levels, + is_segmentation=is_segmentation, + **kwargs) + + +def get_model(config): + def _model_class(class_name): + m = importlib.import_module('pytorch3dunet.unet3d.model') + clazz = getattr(m, class_name) + return clazz + + assert 'model' in config, 'Could not find model configuration' + model_config = config['model'] + model_class = _model_class(model_config['name']) + return model_class(**model_config) + + +if __name__ == "__main__": + """ + testing + """ + in_channels = 1 + out_channels = 1 + f_maps = 32 + num_levels = 2 + model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr') + print(model) + print('number of parameters: ', sum(p.numel() for p in model.parameters())) + + reso = 18 + + import numpy as np + import torch + x = np.zeros((1, 1, reso, reso, reso)) + x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan + x = torch.FloatTensor(x) + + out = model(x) + print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso))) + \ No newline at end of file diff --git a/src/network/utils.py b/src/network/utils.py new file mode 100644 index 0000000..68ea744 --- /dev/null +++ b/src/network/utils.py @@ -0,0 +1,167 @@ +""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ + +import torch +import torch.nn as nn + +class Embedder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs['input_dims'] + out_dim = 0 + if self.kwargs['include_input']: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs['max_freq_log2'] + N_freqs = self.kwargs['num_freqs'] + + if self.kwargs['log_sampling']: + freq_bands = 2. 
** torch.linspace(0., max_freq, N_freqs) + else: + freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs['periodic_fns']: + embed_fns.append(lambda x, p_fn=p_fn, + freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + +def get_embedder(multires, d_in=3): + embed_kwargs = { + 'include_input': True, + 'input_dims': d_in, + 'max_freq_log2': multires-1, + 'num_freqs': multires, + 'log_sampling': True, + 'periodic_fns': [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + def embed(x, eo=embedder_obj): return eo.embed(x) + return embed, embedder_obj.out_dim + +def normalize_coordinate(p, plane='xz'): + ''' Normalize coordinate to [0, 1] for unit cube experiments + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + plane (str): plane feature type, ['xz', 'xy', 'yz'] + ''' + if plane == 'xz': + xy = p[:, :, [0, 2]] + elif plane =='xy': + xy = p[:, :, [0, 1]] + else: + xy = p[:, :, [1, 2]] + + xy_new = xy + # f there are outliers out of the range + if xy_new.max() >= 1: + xy_new[xy_new >= 1] = 1 - 10e-6 + if xy_new.min() < 0: + xy_new[xy_new < 0] = 0.0 + return xy_new + + +def normalize_3d_coordinate(p): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + ''' + if p.max() >= 1: + p[p >= 1] = 1 - 10e-6 + if p.min() < 0: + p[p < 0] = 0.0 + return p + +def coordinate2index(x, reso, coord_type='2d'): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model + + Args: + x (tensor): coordinate + reso (int): defined resolution + coord_type (str): coordinate type + ''' + x = (x * reso).long() + if coord_type == '2d': # plane + index = x[:, :, 0] + reso * x[:, :, 1] + elif coord_type == '3d': # grid + index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2]) + index = index[:, None, :] + return index + + +class map2local(object): + ''' Add new keys to the given input + + Args: + s (float): the defined voxel size + pos_encoding (str): method for the positional encoding, linear|sin_cos + ''' + def __init__(self, s, pos_encoding='linear'): + super().__init__() + self.s = s + # self.pe = positional_encoding(basis_function=pos_encoding, local=True) + + def __call__(self, p): + # p = torch.remainder(p, self.s) / self.s # always possitive + p = (p % self.s) / self.s + p[p < 0] = 0.0 + # p = torch.fmod(p, self.s) / self.s # same sign as input p! + # p = self.pe(p) + return p + +# Resnet Blocks +class ResnetBlockFC(nn.Module): + ''' Fully connected ResNet Block class. 
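+ Computes x_s + fc_1(relu(fc_0(relu(x)))), where the shortcut x_s is a linear
+ projection if size_in != size_out and the identity otherwise; the weights of
+ fc_1 are zero-initialized so the residual branch starts close to zero.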
+ + Args: + size_in (int): input dimension + size_out (int): output dimension + size_h (int): hidden dimension + ''' + + def __init__(self, size_in, size_out=None, size_h=None, siren=False): + super().__init__() + # Attributes + if size_out is None: + size_out = size_in + + if size_h is None: + size_h = min(size_in, size_out) + + self.size_in = size_in + self.size_h = size_h + self.size_out = size_out + # Submodules + self.fc_0 = nn.Linear(size_in, size_h) + self.fc_1 = nn.Linear(size_h, size_out) + self.actvn = nn.ReLU() + + if size_in == size_out: + self.shortcut = None + else: + self.shortcut = nn.Linear(size_in, size_out, bias=False) + # Initialization + nn.init.zeros_(self.fc_1.weight) + + def forward(self, x): + net = self.fc_0(self.actvn(x)) + dx = self.fc_1(self.actvn(net)) + + if self.shortcut is not None: + x_s = self.shortcut(x) + else: + x_s = x + + return x_s + dx \ No newline at end of file diff --git a/src/optimization.py b/src/optimization.py new file mode 100644 index 0000000..12a0c1d --- /dev/null +++ b/src/optimization.py @@ -0,0 +1,349 @@ +import time, os +import numpy as np +import torch +from torch.nn import functional as F +import trimesh + +from src.dpsr import DPSR +from src.model import PSR2Mesh +from src.utils import grid_interp, verts_on_largest_mesh,\ + export_pointcloud, mc_from_psr, GaussianSmoothing +from src.visualize import visualize_points_mesh, visualize_psr_grid, \ + visualize_mesh_phong, render_rgb +from torchvision.utils import save_image +from torchvision.io import write_video +from pytorch3d.loss import chamfer_distance +import open3d as o3d + +class Trainer(object): + ''' + Args: + cfg : config file + optimizer : pytorch optimizer object + device : pytorch device + ''' + + def __init__(self, cfg, optimizer, device=None): + self.optimizer = optimizer + self.device = device + self.cfg = cfg + self.psr2mesh = PSR2Mesh.apply + self.data_type = cfg['data']['data_type'] + + # initialize DPSR + self.dpsr = DPSR(res=(cfg['model']['grid_res'], + cfg['model']['grid_res'], + cfg['model']['grid_res']), + sig=cfg['model']['psr_sigma']) + if torch.cuda.device_count() > 1: + self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR + self.dpsr = self.dpsr.to(device) + + def train_step(self, data, inputs, model, it): + ''' Performs a training step. + + Args: + data (dict) : data dictionary + inputs (torch.tensor) : input point clouds + model (nn.Module or None): a neural network or None + it (int) : the number of iterations + ''' + + self.optimizer.zero_grad() + loss, loss_each = self.compute_loss(inputs, data, model, it) + + loss.backward() + self.optimizer.step() + + return loss.item(), loss_each + + def compute_loss(self, inputs, data, model, it=0): + ''' Compute the loss. + Args: + data (dict) : data dictionary + inputs (torch.tensor) : input point clouds + model (nn.Module or None): a neural network or None + it (int) : the number of iterations + ''' + + device = self.device + res = self.cfg['model']['grid_res'] + + # source oriented point clouds to PSR grid + psr_grid, points, normals = self.pcl2psr(inputs) + + # build mesh + v, f, n = self.psr2mesh(psr_grid) + + # the output is in the range of [0, 1), we make it to the real range [0, 1]. + # This is a hack for our DPSR solver + v = v * res / (res-1) + + points = points * 2. - 1. + v = v * 2. - 1. 
# within the range of (-1, 1) + + loss = 0 + loss_each = {} + # compute loss + if self.data_type == 'point': + if self.cfg['train']['w_chamfer'] > 0: + loss_ = self.cfg['train']['w_chamfer'] * \ + self.compute_3d_loss(v, data) + loss_each['chamfer'] = loss_ + loss += loss_ + elif self.data_type == 'img': + loss, loss_each = self.compute_2d_loss(inputs, data, model) + + return loss, loss_each + + + def pcl2psr(self, inputs): + ''' Convert an oriented point cloud to PSR indicator grid + Args: + inputs (torch.tensor): input oriented point clouds + ''' + + points, normals = inputs[...,:3], inputs[...,3:] + if self.cfg['model']['apply_sigmoid']: + points = torch.sigmoid(points) + if self.cfg['model']['normal_normalize']: + normals = normals / normals.norm(dim=-1, keepdim=True) + + # DPSR to get grid + psr_grid = self.dpsr(points, normals).unsqueeze(1) + psr_grid = torch.tanh(psr_grid) + + return psr_grid, points, normals + + def compute_3d_loss(self, v, data): + ''' Compute the loss for point clouds. + Args: + v (torch.tensor) : mesh vertices + data (dict) : data dictionary + ''' + + pts_gt = data.get('target_points') + idx = np.random.randint(pts_gt.shape[1], size=self.cfg['train']['n_sup_point']) + if self.cfg['train']['subsample_vertex']: + #chamfer distance only on random sampled vertices + idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point']) + loss, _ = chamfer_distance(v[:, idx], pts_gt) + else: + loss, _ = chamfer_distance(v, pts_gt) + + return loss + + def compute_2d_loss(self, inputs, data, model): + ''' Compute the 2D losses. + Args: + inputs (torch.tensor) : input source point clouds + data (dict) : data dictionary + model (nn.Module or None): neural network or None + ''' + + losses = {"color": + {"weight": self.cfg['train']['l_weight']['rgb'], + "values": [] + }, + "silhouette": + {"weight": self.cfg['train']['l_weight']['mask'], + "values": []}, + } + loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses} + + # forward pass + out = model(inputs, data) + + if out['rgb'] is not None: + rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'], + -1, 3)[out['vis_mask']] + loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt, + out['rgb']) / out['rgb'].shape[0] + + if out['mask'] is not None: + loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean() + + # weighted sum of the losses + loss = torch.tensor(0.0, device=self.device) + for k, l in loss_all.items(): + loss += l * losses[k]["weight"] + losses[k]["values"].append(l) + + return loss, loss_all + + def point_resampling(self, inputs): + ''' Resample points + Args: + inputs (torch.tensor): oriented point clouds + ''' + + psr_grid, points, normals = self.pcl2psr(inputs) + + # shortcuts + n_grow = self.cfg['train']['n_grow_points'] + + # [hack] for points resampled from the mesh from marching cubes, + # we need to divide by s instead of (s-1), and the scale is correct. 
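+ # run marching cubes on the current PSR grid (real_scale=False keeps the
+ # vertices in [0, 1)) and resample oriented points from the largest component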
+ verts, faces, _ = mc_from_psr(psr_grid, real_scale=False, zero_level=0) + + # find the largest component + pts_mesh, faces_mesh = verts_on_largest_mesh(verts, faces) + + # sample vertices only from the largest component, not from fragments + mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh) + pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True) + normals_i = mesh.face_normals[face_idx].astype('float32') + pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None] + n_mesh = torch.tensor(normals_i).to(self.device)[None] + + points, normals = pts_mesh, n_mesh + print('{} total points are resampled'.format(points.shape[1])) + + # update inputs + points = torch.log(points / (1 - points)) # inverse sigmoid + inputs = torch.cat([points, normals], dim=-1) + inputs.requires_grad = True + + return inputs + + def visualize(self, data, inputs, renderer, epoch, o3d_vis=None): + ''' Visualization. + Args: + data (dict) : data dictionary + inputs (torch.tensor) : source point clouds + renderer (nn.Module or None): a neural network or None + epoch (int) : the number of iterations + o3d_vis (o3d.Visualizer) : open3d visualizer + ''' + + data_type = self.cfg['data']['data_type'] + it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every'])) + + + if (self.cfg['train']['exp_mesh']) \ + | (self.cfg['train']['exp_pcl']) \ + | (self.cfg['train']['o3d_show']): + psr_grid, points, normals = self.pcl2psr(inputs) + + with torch.no_grad(): + v, f, n = mc_from_psr(psr_grid, pytorchify=True, + zero_level=self.cfg['data']['zero_level'], real_scale=True) + v, f, n = v[None], f[None], n[None] + + v = v * 2. - 1. # change to the range of [-1, 1] + + color_v = None + if data_type == 'img': + if self.cfg['train']['vis_vert_color'] & \ + (self.cfg['train']['l_weight']['rgb'] != 0.): + color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy() + color_v[color_v<0], color_v[color_v>1] = 0., 1. + + vv = v.detach().squeeze().cpu().numpy() + ff = f.detach().squeeze().cpu().numpy() + points = points * 2 - 1 + visualize_points_mesh(o3d_vis, points, normals, + vv, ff, self.cfg, it, epoch, color_v=color_v) + + else: + v, f, n = inputs + + + if (data_type == 'img') & (self.cfg['train']['vis_rendering']): + pred_imgs = [] + pred_masks = [] + n_views = len(data['poses']) + # idx_list = trange(n_views) + idx_list = [13, 24, 27, 48] + + #! + model = renderer.eval() + for idx in idx_list: + pose = data['poses'][idx] + rgb = data['rgbs'][idx] + mask_gt = data['masks'][idx] + img_size = rgb.shape[0] if rgb.shape[0]== rgb.shape[1] else (rgb.shape[0], rgb.shape[1]) + ray = None + if 'rays' in data.keys(): + ray = data['rays'][idx] + if self.cfg['train']['l_weight']['rgb'] != 0.: + fea_grid = None + if model.unet3d is not None: + with torch.no_grad(): + fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1) + if model.encoder is not None: + pp = torch.cat([(points+1)/2, normals], dim=-1) + fea_grid = model.encoder(pp, + normalize=False).permute(0, 2, 3, 4, 1) + + pred, visible_mask = render_rgb(v, f, n, pose, + model.rendering_network.eval(), + img_size, ray=ray, fea_grid=fea_grid) + img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3]) + img_pred[visible_mask] = pred.detach().cpu() + + img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3) + img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1. + filename=os.path.join(self.cfg['train']['dir_rendering'], + 'rendering_{}_{:d}.png'.format(it, idx)) + save_image(img_pred.permute(2, 0, 1), filename) + pred_imgs.append(img_pred) + + #! 
Mesh rendering using Phong shading model + filename=os.path.join(self.cfg['train']['dir_rendering'], + 'mesh_{}_{:d}.png'.format(it, idx)) + visualize_mesh_phong(v, f, n, pose, img_size, name=filename) + + if len(pred_imgs) >= 1: + pred_imgs = torch.stack(pred_imgs, dim=0) + save_image(pred_imgs.permute(0, 3, 1, 2), + os.path.join(self.cfg['train']['dir_rendering'], + '{}.png'.format(it)), nrow=4) + if self.cfg['train']['save_video']: + write_video(os.path.join(self.cfg['train']['dir_rendering'], + '{}.mp4'.format(it)), + (pred_imgs*255.).type(torch.uint8), fps=24) + + def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None): + ''' Save meshes and point clouds. + Args: + inputs (torch.tensor) : source point clouds + epoch (int) : the number of iterations + center (numpy.array) : center of the shape + scale (numpy.array) : scale of the shape + ''' + + exp_pcl = self.cfg['train']['exp_pcl'] + exp_mesh = self.cfg['train']['exp_mesh'] + + psr_grid, points, normals = self.pcl2psr(inputs) + + if exp_pcl: + dir_pcl = self.cfg['train']['dir_pcl'] + p = points.squeeze(0).detach().cpu().numpy() + p = p * 2 - 1 + n = normals.squeeze(0).detach().cpu().numpy() + if scale is not None: + p *= scale + if center is not None: + p += center + export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n) + if exp_mesh: + dir_mesh = self.cfg['train']['dir_mesh'] + with torch.no_grad(): + v, f, _ = mc_from_psr(psr_grid, + zero_level=self.cfg['data']['zero_level'], real_scale=True) + v = v * 2 - 1 + if scale is not None: + v *= scale + if center is not None: + v += center + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(v) + mesh.triangles = o3d.utility.Vector3iVector(f) + outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch)) + o3d.io.write_triangle_mesh(outdir_mesh, mesh) + + if self.cfg['train']['vis_psr']: + dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all' + visualize_psr_grid(psr_grid, out_dir=dir_psr_vis) diff --git a/src/training.py b/src/training.py new file mode 100644 index 0000000..b38a1d4 --- /dev/null +++ b/src/training.py @@ -0,0 +1,207 @@ +import os +import numpy as np +import torch +from torch.nn import functional as F +from collections import defaultdict +import trimesh +from tqdm import tqdm + +from src.dpsr import DPSR +from src.utils import grid_interp, export_pointcloud, export_mesh, \ + mc_from_psr, scale2onet, GaussianSmoothing +from pytorch3d.ops.knn import knn_gather, knn_points +from pytorch3d.loss import chamfer_distance +from pdb import set_trace as st + +class Trainer(object): + ''' + Args: + model (nn.Module): our defined model + optimizer (optimizer): pytorch optimizer object + device (device): pytorch device + input_type (str): input type + vis_dir (str): visualization directory + ''' + def __init__(self, cfg, optimizer, device=None): + self.optimizer = optimizer + self.device = device + self.cfg = cfg + if self.cfg['train']['w_raw'] != 0: + from src.model import PSR2Mesh + self.psr2mesh = PSR2Mesh.apply + + # initialize DPSR + self.dpsr = DPSR(res=(cfg['model']['grid_res'], + cfg['model']['grid_res'], + cfg['model']['grid_res']), + sig=cfg['model']['psr_sigma']) + if torch.cuda.device_count() > 1: + self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR + self.dpsr = self.dpsr.to(device) + + if cfg['train']['gauss_weight']>0.: + self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device) + + def train_step(self, inputs, data, model): + ''' Performs a training step. 
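+ Runs the network on the input point cloud, converts the predicted oriented
+ points to a PSR indicator grid via DPSR, computes the enabled loss terms
+ (PSR grid loss, point/normal regularization, raw Chamfer loss) and applies
+ one optimizer update. A minimal usage sketch (assuming a configured trainer,
+ model and dataloader batch):
+ loss, loss_each = trainer.train_step(inputs, batch, model)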
+ + Args: + data (dict): data dictionary + ''' + self.optimizer.zero_grad() + p = data.get('inputs').to(self.device) + + out = model(p) + + points, normals = out + + loss = 0 + loss_each = {} + if self.cfg['train']['w_psr'] != 0: + psr_gt = data.get('gt_psr').to(self.device) + if self.cfg['model']['psr_tanh']: + psr_gt = torch.tanh(psr_gt) + + psr_grid = self.dpsr(points, normals) + if self.cfg['model']['psr_tanh']: + psr_grid = torch.tanh(psr_grid) + + # apply a rescaling weight based on GT SDF values + if self.cfg['train']['gauss_weight']>0: + gauss_sigma = self.cfg['train']['gauss_weight'] + # set up the weighting for loss, higher weights + # for points near to the surface + psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1) + delta_x = delta_y = delta_z = 1 + grad_x = (psr_gt_pad[:, 2:, :, :] - psr_gt_pad[:, :-2, :, :]) / 2 / delta_x + grad_y = (psr_gt_pad[:, :, 2:, :] - psr_gt_pad[:, :, :-2, :]) / 2 / delta_y + grad_z = (psr_gt_pad[:, :, :, 2:] - psr_gt_pad[:, :, :, :-2]) / 2 / delta_z + grad_x = grad_x[:, :, 1:-1, 1:-1] + grad_y = grad_y[:, 1:-1, :, 1:-1] + grad_z = grad_z[:, 1:-1, 1:-1, :] + psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1) + psr_grad_norm = psr_grad.norm(dim=-1)[:, None] + w = torch.nn.ReplicationPad3d(3)(psr_grad_norm) + w = 2*self.gauss_smooth(w).squeeze(1) + loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt) + else: + loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt) + + loss += loss_each['psr'] + + # regularization on the input point positions via chamfer distance + if self.cfg['train']['w_reg_point'] != 0.: + points_gt = data.get('gt_points').to(self.device) + loss_reg, loss_norm = chamfer_distance(points, points_gt) + + loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg + loss += loss_each['reg'] + + if self.cfg['train']['w_normals'] != 0.: + points_gt = data.get('gt_points').to(self.device) + normals_gt = data.get('gt_points.normals').to(self.device) + x_nn = knn_points(points, points_gt, K=1) + x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :] + + cham_norm_x = F.l1_loss(normals, x_normals_near) + loss_norm = cham_norm_x + + loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm + loss += loss_each['normals'] + + if self.cfg['train']['w_raw'] != 0: + res = self.cfg['model']['grid_res'] + # DPSR to get grid + psr_grid = self.dpsr(points, normals) + if self.cfg['model']['psr_tanh']: + psr_grid = torch.tanh(psr_grid) + + v, f, n = self.psr2mesh(psr_grid) + + pts_gt = data.get('gt_points').to(self.device) + + loss, _ = chamfer_distance(v, pts_gt) + + loss.backward() + self.optimizer.step() + + return loss.item(), loss_each + + def save(self, model, data, epoch, id): + + p = data.get('inputs').to(self.device) + + exp_pcl = self.cfg['train']['exp_pcl'] + exp_mesh = self.cfg['train']['exp_mesh'] + exp_gt = self.cfg['generation']['exp_gt'] + exp_input = self.cfg['generation']['exp_input'] + + model.eval() + with torch.no_grad(): + points, normals = model(p) + + if exp_gt: + points_gt = data.get('gt_points').to(self.device) + normals_gt = data.get('gt_points.normals').to(self.device) + + if exp_pcl: + dir_pcl = self.cfg['train']['dir_pcl'] + export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals) + if exp_gt: + export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt) + if exp_input: + export_pointcloud(os.path.join(dir_pcl, 
'{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p)) + + if exp_mesh: + dir_mesh = self.cfg['train']['dir_mesh'] + psr_grid = self.dpsr(points, normals) + # psr_grid = torch.tanh(psr_grid) + with torch.no_grad(): + v, f, _ = mc_from_psr(psr_grid, + zero_level=self.cfg['data']['zero_level']) + outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id)) + export_mesh(outdir_mesh, scale2onet(v), f) + if exp_gt: + psr_gt = self.dpsr(points_gt, normals_gt) + with torch.no_grad(): + v, f, _ = mc_from_psr(psr_gt, + zero_level=self.cfg['data']['zero_level']) + export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f) + + def evaluate(self, val_loader, model): + ''' Performs an evaluation. + Args: + val_loader (dataloader): pytorch dataloader + ''' + eval_list = defaultdict(list) + + for data in tqdm(val_loader): + eval_step_dict = self.eval_step(data, model) + + for k, v in eval_step_dict.items(): + eval_list[k].append(v) + + eval_dict = {k: np.mean(v) for k, v in eval_list.items()} + return eval_dict + + def eval_step(self, data, model): + ''' Performs an evaluation step. + Args: + data (dict): data dictionary + ''' + model.eval() + eval_dict = {} + + p = data.get('inputs').to(self.device) + psr_gt = data.get('gt_psr').to(self.device) + + with torch.no_grad(): + # forward pass + points, normals = model(p) + # DPSR to get predicted psr grid + psr_grid = self.dpsr(points, normals) + + eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item() + eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item() + + return eval_dict \ No newline at end of file diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..b9ce35b --- /dev/null +++ b/src/utils.py @@ -0,0 +1,645 @@ +import torch +import io, os, logging, urllib +import yaml +import trimesh +import imageio +import numbers +import math +import numpy as np +from collections import OrderedDict +from plyfile import PlyData +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo +from skimage import measure, img_as_float32 +from pytorch3d.structures import Meshes +from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes +from igl import adjacency_matrix, connected_components +import open3d as o3d + +################################################## +# Below are functions for DPSR + +def fftfreqs(res, dtype=torch.float32, exact=True): + """ + Helper function to return frequency tensors + :param res: n_dims int tuple of number of frequency modes + :return: + """ + + n_dims = len(res) + freqs = [] + for dim in range(n_dims - 1): + r_ = res[dim] + freq = np.fft.fftfreq(r_, d=1/r_) + freqs.append(torch.tensor(freq, dtype=dtype)) + r_ = res[-1] + if exact: + freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype)) + else: + freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype)) + omega = torch.meshgrid(freqs) + omega = list(omega) + omega = torch.stack(omega, dim=-1) + + return omega + +def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag) + """ + multiply tensor x by i ** deg + """ + deg %= 4 + if deg == 0: + res = x + elif deg == 1: + res = x[..., [1, 0]] + res[..., 0] = -res[..., 0] + elif deg == 2: + res = -x + elif deg == 3: + res = x[..., [1, 0]] + res[..., 1] = -res[..., 1] + return res + +def spec_gaussian_filter(res, sig): + omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d] + dis = torch.sqrt(torch.sum(omega ** 2, dim=-1)) + filter_ = 
torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1) + filter_.requires_grad = False + + return filter_ + +def grid_interp(grid, pts, batched=True): + """ + :param grid: tensor of shape (batch, *size, in_features) + :param pts: tensor of shape (batch, num_points, dim) within range (0, 1) + :return values at query points + """ + if not batched: + grid = grid.unsqueeze(0) + pts = pts.unsqueeze(0) + dim = pts.shape[-1] + bs = grid.shape[0] + size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype) + cubesize = 1.0 / size + + ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim) + ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around + ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim) + tmp = torch.tensor([0,1],dtype=torch.long) + com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim) + dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim) + ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points) + ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) + ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim) + # latent code on neighbor nodes + if dim == 2: + lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features) + else: + lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features) + + # weights of neighboring nodes + xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim) + xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim) + xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim) + pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) + pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) + pos_ = pos_.type(pts.dtype) + dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim) + weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim) + query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features) + if not batched: + query_values = query_values.squeeze(0) + + return query_values + +def scatter_to_grid(inds, vals, size): + """ + Scatter update values into empty tensor of size size. + :param inds: (#values, dims) + :param vals: (#values) + :param size: tuple for size. 
len(size)=dims + """ + dims = inds.shape[1] + assert(inds.shape[0] == vals.shape[0]) + assert(len(size) == dims) + dev = vals.device + # result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten + # # flatten inds + result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten + # flatten inds + fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1] + fac = torch.tensor(fac, device=dev).type(inds.dtype) + inds_fold = torch.sum(inds*fac, dim=-1) # [#values,] + result.scatter_add_(0, inds_fold, vals) + result = result.view(*size) + return result + +def point_rasterize(pts, vals, size): + """ + :param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1) + :param vals: point values, tensor of shape (batch, num_points, features) + :param size: len(size)=dim tuple for grid size + :return rasterized values (batch, features, res0, res1, res2) + """ + dim = pts.shape[-1] + assert(pts.shape[:2] == vals.shape[:2]) + assert(pts.shape[2] == dim) + size_list = list(size) + size = torch.tensor(size).to(pts.device).float() + cubesize = 1.0 / size + bs = pts.shape[0] + nf = vals.shape[-1] + npts = pts.shape[1] + dev = pts.device + + ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim) + ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around + ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim) + tmp = torch.tensor([0,1],dtype=torch.long) + com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim) + dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim) + ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points) + ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) + # ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim) + ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim) + + # weights of neighboring nodes + xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim) + xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim) + xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim) + pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) + pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) + pos_ = pos_.type(pts.dtype) + dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim) + weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim) + + ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1) + ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim) + ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1) + # ind_f = torch.arange(nf).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1) + + ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1) + ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev) + ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1) + inds = torch.cat([ind_b, ind_f, ind_n], dim=-1) # (batch, num_points, 2**dim, nf, 1+1+dim) + + # weighted values + vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf) + + inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf) + vals = vals.reshape(-1) # (bs*npts*2**dim*nf) + tensor_size = [bs, nf] + size_list + raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list) + + return raster # 
[batch, nf, res, res, res] + + + +################################################## +# Below are the utilization functions in general + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.n = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.n = n + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + @property + def valcavg(self): + return self.val.sum().item() / (self.n != 0).sum().item() + + @property + def avgcavg(self): + return self.avg.sum().item() / (self.count != 0).sum().item() + +def load_model_manual(state_dict, model): + new_state_dict = OrderedDict() + is_model_parallel = isinstance(model, torch.nn.DataParallel) + for k, v in state_dict.items(): + if k.startswith('module.') != is_model_parallel: + if k.startswith('module.'): + # remove module + k = k[7:] + else: + # add module + k = 'module.' + k + + new_state_dict[k]=v + + model.load_state_dict(new_state_dict) + +def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0): + ''' + Run marching cubes from PSR grid + ''' + batch_size = psr_grid.shape[0] + s = psr_grid.shape[-1] # size of psr_grid + psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy() + + if batch_size>1: + verts, faces, normals = [], [], [] + for i in range(batch_size): + verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0) + verts.append(verts_cur) + faces.append(faces_cur) + normals.append(normals_cur) + verts = np.stack(verts, axis = 0) + faces = np.stack(faces, axis = 0) + normals = np.stack(normals, axis = 0) + else: + try: + verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level) + except: + verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy) + if real_scale: + verts = verts / (s-1) # scale to range [0, 1] + else: + verts = verts / s # scale to range [0, 1) + + if pytorchify: + device = psr_grid.device + verts = torch.Tensor(np.ascontiguousarray(verts)).to(device) + faces = torch.Tensor(np.ascontiguousarray(faces)).to(device) + normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device) + + return verts, faces, normals + +def calc_inters_points(verts, faces, pose, img_size, mask_gt=None): + verts = verts.squeeze() + faces = faces.squeeze() + pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size) + if mask_gt is not None: + #! 
only evaluate within the intersection + mask = mask & mask_gt + # find 3D points intesected on the mesh + if True: + w_masked = w[mask] + f_p = faces[pix_to_face[mask]].long() # cooresponding faces for each pixel + # corresponding vertices for p_closest + v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]] + + # calculate the intersection point of each pixel and the mesh + p_inters = w_masked[..., 0, None] * v_a + \ + w_masked[..., 1, None] * v_b + \ + w_masked[..., 2, None] * v_c + else: + # backproject ndc to world coordinates using z-buffer + W, H = img_size[1], img_size[0] + xy = uv.to(mask.device)[mask] + x_ndc = 1 - (2*xy[:, 0]) / (W - 1) + y_ndc = 1 - (2*xy[:, 1]) / (H - 1) + z = zbuf.squeeze().reshape(H * W)[mask] + xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1) + + p_inters = pose.unproject_points(xy_depth, world_coordinates=True) + + # if there are outlier points, we should remove it + if (p_inters.max()>1) | (p_inters.min()<-1): + mask_bound = (p_inters>=-1) & (p_inters<=1) + mask_bound = (mask_bound.sum(dim=-1)==3) + mask[mask==True] = mask_bound + p_inters = p_inters[mask_bound] + print('!!!!!find outlier!') + + return p_inters, mask, f_p, w_masked + +def mesh_rasterization(verts, faces, pose, img_size): + ''' + Use PyTorch3D to rasterize the mesh given a camera + ''' + transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system + if isinstance(pose, PerspectiveCameras): + transformed_v[..., 2] = 1/transformed_v[..., 2] + # find p_closest on mesh of each pixel via rasterization + transformed_mesh = Meshes(verts=[transformed_v], faces=[faces]) + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + transformed_mesh, + image_size=img_size, + blur_radius=0, + faces_per_pixel=1, + perspective_correct=False + ) + pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso) + mask = pix_to_face.clone() != -1 + mask = mask.squeeze() + pix_to_face = pix_to_face.squeeze() + w = bary_coords.reshape(-1, 3) + + return pix_to_face, w, mask + +def verts_on_largest_mesh(verts, faces): + ''' + verts: Numpy array or Torch.Tensor (N, 3) + faces: Numpy array (N, 3) + ''' + if torch.is_tensor(faces): + verts = verts.squeeze().detach().cpu().numpy() + faces = faces.squeeze().int().detach().cpu().numpy() + + A = adjacency_matrix(faces) + num, conn_idx, conn_size = connected_components(A) + if num == 0: + v_large, f_large = verts, faces + else: + max_idx = conn_size.argmax() # find the index of the largest component + v_large = verts[conn_idx==max_idx] # keep points on the largest component + + if True: + mesh_largest = trimesh.Trimesh(verts, faces) + connected_comp = mesh_largest.split(only_watertight=False) + mesh_largest = connected_comp[max_idx] + v_large, f_large = mesh_largest.vertices, mesh_largest.faces + v_large = v_large.astype(np.float32) + return v_large, f_large + +def load_pointcloud(in_file): + plydata = PlyData.read(in_file) + vertices = np.stack([ + plydata['vertex']['x'], + plydata['vertex']['y'], + plydata['vertex']['z'] + ], axis=1) + return vertices + +# General config +def load_config(path, default_path=None): + ''' Loads config file. 
+
+    Args:
+        path (str): path to config file
+        default_path (str): path to the default config file
+    '''
+    # Load configuration from file itself
+    with open(path, 'r') as f:
+        cfg_special = yaml.load(f, Loader=yaml.Loader)
+
+    # Check if we should inherit from a config
+    inherit_from = cfg_special.get('inherit_from')
+
+    # If yes, load this config first as default
+    # If no, use the default_path
+    if inherit_from is not None:
+        cfg = load_config(inherit_from, default_path)
+    elif default_path is not None:
+        with open(default_path, 'r') as f:
+            cfg = yaml.load(f, Loader=yaml.Loader)
+    else:
+        cfg = dict()
+
+    # Include main configuration
+    update_recursive(cfg, cfg_special)
+
+    return cfg
+
+def update_config(config, unknown):
+    # update config given args
+    for idx, arg in enumerate(unknown):
+        if arg.startswith("--"):
+            keys = arg.replace("--","").split(':')
+            assert(len(keys)==2)
+            k1, k2 = keys
+            argtype = type(config[k1][k2])
+            if argtype == bool:
+                v = unknown[idx+1].lower() == 'true'
+            else:
+                if config[k1][k2] is not None:
+                    v = type(config[k1][k2])(unknown[idx+1])
+                else:
+                    v = unknown[idx+1]
+            print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
+            config[k1][k2] = v
+    return config
+
+def initialize_logger(cfg):
+    out_dir = cfg['train']['out_dir']
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+
+    cfg['train']['dir_model'] = os.path.join(out_dir, 'model')
+    os.makedirs(cfg['train']['dir_model'], exist_ok=True)
+
+    if cfg['train']['exp_mesh']:
+        cfg['train']['dir_mesh'] = os.path.join(out_dir, 'vis/mesh')
+        os.makedirs(cfg['train']['dir_mesh'], exist_ok=True)
+    if cfg['train']['exp_pcl']:
+        cfg['train']['dir_pcl'] = os.path.join(out_dir, 'vis/pointcloud')
+        os.makedirs(cfg['train']['dir_pcl'], exist_ok=True)
+    if cfg['train']['vis_rendering']:
+        cfg['train']['dir_rendering'] = os.path.join(out_dir, 'vis/rendering')
+        os.makedirs(cfg['train']['dir_rendering'], exist_ok=True)
+    if cfg['train']['o3d_show']:
+        cfg['train']['dir_o3d'] = os.path.join(out_dir, 'vis/o3d')
+        os.makedirs(cfg['train']['dir_o3d'], exist_ok=True)
+
+
+    logger = logging.getLogger("train")
+    logger.setLevel(logging.DEBUG)
+    logger.handlers = []
+    # ch = logging.StreamHandler()
+    # logger.addHandler(ch)
+    fh = logging.FileHandler(os.path.join(cfg['train']['out_dir'], "log.txt"))
+    logger.addHandler(fh)
+    logger.info('Output dir: %s', out_dir)
+    return logger
+
+def update_recursive(dict1, dict2):
+    ''' Update two config dictionaries recursively.
+ + Args: + dict1 (dict): first dictionary to be updated + dict2 (dict): second dictionary which entries should be used + + ''' + for k, v in dict2.items(): + if k not in dict1: + dict1[k] = dict() + if isinstance(v, dict): + update_recursive(dict1[k], v) + else: + dict1[k] = v + +def export_pointcloud(name, points, normals=None): + if len(points.shape) > 2: + points = points[0] + if normals is not None: + normals = normals[0] + if isinstance(points, torch.Tensor): + points = points.detach().cpu().numpy() + if normals is not None: + normals = normals.detach().cpu().numpy() + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + if normals is not None: + pcd.normals = o3d.utility.Vector3dVector(normals) + o3d.io.write_point_cloud(name, pcd) + +def export_mesh(name, v, f): + if len(v.shape) > 2: + v, f = v[0], f[0] + if isinstance(v, torch.Tensor): + v = v.detach().cpu().numpy() + f = f.detach().cpu().numpy() + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(v) + mesh.triangles = o3d.utility.Vector3iVector(f) + o3d.io.write_triangle_mesh(name, mesh) + +def scale2onet(p, scale=1.2): + ''' + Scale the point cloud from SAP to ONet range + ''' + return (p - 0.5) * scale + +def update_optimizer(inputs, cfg, epoch, model=None, schedule=None): + if model is not None: + if schedule is not None: + optimizer = torch.optim.Adam([ + {"params": model.parameters(), + "lr": schedule[0].get_learning_rate(epoch)}, + {"params": inputs, + "lr": schedule[1].get_learning_rate(epoch)}]) + elif 'lr' in cfg['train']: + optimizer = torch.optim.Adam([ + {"params": model.parameters(), + "lr": float(cfg['train']['lr'])}, + {"params": inputs, + "lr": float(cfg['train']['lr_pcl'])}]) + else: + raise Exception('no known learning rate') + else: + if schedule is not None: + optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch)) + else: + optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl'])) + + return optimizer + + +def is_url(url): + scheme = urllib.parse.urlparse(url).scheme + return scheme in ('http', 'https') + +def load_url(url): + '''Load a module dictionary from url. + + Args: + url (str): url to saved model + ''' + print(url) + print('=> Loading checkpoint from url...') + state_dict = model_zoo.load_url(url, progress=True) + + return state_dict + + +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + def __init__(self, channels, kernel_size, sigma, dim=3): + super(GaussianSmoothing, self).__init__() + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. 
+ kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + torch.exp(-((mgrid - mean) / std) ** 2 / 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) + ) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight, groups=self.groups) + +# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py +def get_learning_rate_schedules(schedule_specs): + + schedules = [] + + for key in schedule_specs.keys(): + schedules.append(StepLearningRateSchedule( + schedule_specs[key]['initial'], + schedule_specs[key]["interval"], + schedule_specs[key]["factor"], + schedule_specs[key]["final"])) + return schedules + +class LearningRateSchedule: + def get_learning_rate(self, epoch): + pass +class StepLearningRateSchedule(LearningRateSchedule): + def __init__(self, initial, interval, factor, final=1e-6): + self.initial = float(initial) + self.interval = interval + self.factor = factor + self.final = float(final) + + def get_learning_rate(self, epoch): + lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6) + if lr > self.final: + return lr + else: + return self.final + +def adjust_learning_rate(lr_schedules, optimizer, epoch): + for i, param_group in enumerate(optimizer.param_groups): + param_group["lr"] = lr_schedules[i].get_learning_rate(epoch) diff --git a/src/visualize.py b/src/visualize.py new file mode 100644 index 0000000..0eec2da --- /dev/null +++ b/src/visualize.py @@ -0,0 +1,175 @@ +import os +import torch +import numpy as np +import open3d as o3d +import matplotlib.pyplot as plt +from skimage import measure +from src.utils import calc_inters_points, grid_interp +from scipy import ndimage +from tqdm import trange +from torchvision.utils import save_image +from pdb import set_trace as st + + +def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None): + ''' Visualization. 
+ + Args: + data (dict): data dictionary + depth (int): PSR depth + out_path (str): output path for the mesh + ''' + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(verts) + mesh.triangles = o3d.utility.Vector3iVector(faces) + mesh.paint_uniform_color(np.array([0.7,0.7,0.7])) + if color_v is not None: + mesh.vertex_colors = o3d.utility.Vector3dVector(color_v) + + if vis is not None: + dir_o3d = cfg['train']['dir_o3d'] + wire = o3d.geometry.LineSet.create_from_triangle_mesh(mesh) + + p = points.squeeze(0).detach().cpu().numpy() + n = normals.squeeze(0).detach().cpu().numpy() + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(p) + pcd.normals = o3d.utility.Vector3dVector(n) + pcd.paint_uniform_color(np.array([0.7,0.7,1.0])) + # pcd = pcd.uniform_down_sample(5) + + vis.clear_geometries() + vis.add_geometry(mesh) + vis.update_geometry(mesh) + + #! Thingi wheel - an example for how to change cameras in Open3D viewers + vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ]) + vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ]) + vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ]) + vis.get_view_control().set_zoom(0.7) + vis.poll_events() + + out_path = os.path.join(dir_o3d, '{}.jpg'.format(it)) + vis.capture_screen_image(out_path) + + vis.clear_geometries() + vis.add_geometry(pcd, reset_bounding_box=False) + vis.update_geometry(pcd) + vis.get_render_option().point_show_normal=True # visualize point normals + + vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ]) + vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ]) + vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ]) + vis.get_view_control().set_zoom(0.7) + vis.poll_events() + + out_path = os.path.join(dir_o3d, '{}_pcd.jpg'.format(it)) + vis.capture_screen_image(out_path) + +def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'): + if pose is not None: + device = psr_grid.device + # get world coordinate of grid points [-1, 1] + res = psr_grid.shape[-1] + x = torch.linspace(-1, 1, steps=res) + co_x, co_y, co_z = torch.meshgrid(x, x, x) + co_grid = torch.stack( + [co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)], + dim=1).to(device).unsqueeze(0) + + # visualize the projected occ_soft value + res = 128 + psr_grid = psr_grid.reshape(-1) + out_mask = psr_grid>0 + in_mask = psr_grid<0 + pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze() + vis_mask = (pix[..., 0]>=0) & (pix[..., 0]<=res-1) & \ + (pix[..., 1]>=0) & (pix[..., 1]<=res-1) + pix_out = pix[vis_mask & out_mask] + pix_in = pix[vis_mask & in_mask] + + img = torch.ones([res,res]).to(device) + psr_grid = torch.sigmoid(- psr_grid * 5) + img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask] + img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask] + # save_image(img, 'tmp.png', normalize=True) + return img + elif out_dir is not None: + dir_psr_vis = out_dir + os.makedirs(dir_psr_vis, exist_ok=True) + psr_grid = psr_grid.squeeze().detach().cpu().numpy() + axis = ['x', 'y', 'z'] + s = psr_grid.shape[0] + for idx in trange(s): + my_dpi = 100 + plt.figure(figsize=(1000/my_dpi, 300/my_dpi), dpi=my_dpi) + plt.subplot(1, 3, 1) + plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode='nearest'), cmap='nipy_spectral') + plt.clim(-1, 1) + plt.colorbar() + plt.title('x') + plt.grid("off") + plt.axis("off") + + plt.subplot(1, 3, 2) + plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode='nearest'), 
cmap='nipy_spectral') + plt.clim(-1, 1) + plt.colorbar() + plt.title('y') + plt.grid("off") + plt.axis("off") + + plt.subplot(1, 3, 3) + plt.imshow(ndimage.rotate(psr_grid[:,:,idx], 90, mode='nearest'), cmap='nipy_spectral') + plt.clim(-1, 1) + plt.colorbar() + plt.title('z') + plt.grid("off") + plt.axis("off") + + + plt.savefig(os.path.join(dir_psr_vis, '{}'.format(idx)), pad_inches = 0, dpi=100) + plt.close() + os.system("rm {}/{}".format(dir_psr_vis, out_video_name)) + os.system("ffmpeg -framerate 25 -start_number 0 -i {}/%d.png -pix_fmt yuv420p -crf 17 {}/{}".format(dir_psr_vis, dir_psr_vis, out_video_name)) + +def visualize_mesh_phong(v, f, n, pose, img_size, name, device='cpu'): + #! Mesh rendering using Phong shading model + _, mask, f_p, w = calc_inters_points(v, f, pose, img_size) + n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]] + n_inters = w[..., 0, None] * n_a.squeeze() + \ + w[..., 1, None] * n_b.squeeze() + \ + w[..., 2, None] * n_c.squeeze() + n_inters = n_inters.detach().to(device) + light_source = -pose.R@pose.T.squeeze() + light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float() + diffuse_per = torch.Tensor([0.7,0.7,0.7]).float() + ambiant = torch.Tensor([0.3,0.3,0.3]).float() + + diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device) + + phong = torch.ones([img_size[0]*img_size[1], 3]).to(device) + phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0) + pp = phong.reshape(img_size[0], img_size[1], -1) + save_image(pp.permute(2, 0, 1), name) + +def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None): + p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt) + # normals for p_inters + n_inters = None + if n is not None: + n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]] + n_inters = w[..., 0, None] * n_a.squeeze() + \ + w[..., 1, None] * n_b.squeeze() + \ + w[..., 2, None] * n_c.squeeze() + if ray is not None: + ray = ray.squeeze()[mask] + + fea = None + if fea_grid is not None: + fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze() + + # use MLP to regress color + color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze() + + return color_pred, mask \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..2bc08f8 --- /dev/null +++ b/train.py @@ -0,0 +1,185 @@ +import os + +abspath = os.path.abspath(__file__) +dname = os.path.dirname(abspath) +os.chdir(dname) + +import torch +import torch.optim as optim +import open3d as o3d + +import numpy as np; np.set_printoptions(precision=4) +import shutil, argparse, time +from torch.utils.tensorboard import SummaryWriter + +from src import config +from src.data import collate_remove_none, collate_stack_together, worker_init_fn +from src.training import Trainer +from src.model import Encode2Points +from src.utils import load_config, initialize_logger, \ +AverageMeter, load_model_manual + + +def main(): + parser = argparse.ArgumentParser(description='MNIST toy experiment') + parser.add_argument('config', type=str, help='Path to config file.') + parser.add_argument('--no_cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') + + args = parser.parse_args() + cfg = load_config(args.config, 'configs/default.yaml') + use_cuda = 
not args.no_cuda and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + input_type = cfg['data']['input_type'] + batch_size = cfg['train']['batch_size'] + model_selection_metric = cfg['train']['model_selection_metric'] + + # PYTORCH VERSION > 1.0.0 + assert(float(torch.__version__.split('.')[-3]) > 0) + + # boiler-plate + if cfg['train']['timestamp']: + cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S") + logger = initialize_logger(cfg) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml')) + + logger.info("using GPU: " + torch.cuda.get_device_name(0)) + + # TensorboardX writer + tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log") + if not os.path.exists(tblogdir): + os.makedirs(tblogdir, exist_ok=True) + writer = SummaryWriter(log_dir=tblogdir) + + + inputs = None + train_dataset = config.get_dataset('train', cfg) + val_dataset = config.get_dataset('val', cfg) + vis_dataset = config.get_dataset('vis', cfg) + + + collate_fn = collate_remove_none + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'], shuffle=True, + collate_fn=collate_fn, + worker_init_fn=worker_init_fn) + + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False, + collate_fn=collate_remove_none, + worker_init_fn=worker_init_fn) + + vis_loader = torch.utils.data.DataLoader( + vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False, + collate_fn=collate_fn, + worker_init_fn=worker_init_fn) + + if torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(Encode2Points(cfg)).to(device) + else: + model = Encode2Points(cfg).to(device) + + n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info('Number of parameters: %d'% n_parameter) + # load model + try: + # load model + state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt')) + load_model_manual(state_dict['state_dict'], model) + + out = "Load model from iteration %d" % state_dict.get('it', 0) + logger.info(out) + # load point cloud + except: + state_dict = dict() + + metric_val_best = state_dict.get( + 'loss_val_best', np.inf) + + logger.info('Current best validation metric (%s): %.8f' + % (model_selection_metric, metric_val_best)) + + LR = float(cfg['train']['lr']) + optimizer = optim.Adam(model.parameters(), lr=LR) + + start_epoch = state_dict.get('epoch', -1) + it = state_dict.get('it', -1) + + trainer = Trainer(cfg, optimizer, device=device) + runtime = {} + runtime['all'] = AverageMeter() + + # training loop + for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1): + + for batch in train_loader: + it += 1 + + start = time.time() + loss, loss_each = trainer.train_step(inputs, batch, model) + + # measure elapsed time + end = time.time() + runtime['all'].update(end - start) + + if it % cfg['train']['print_every'] == 0: + log_text = ('[Epoch %02d] it=%d, loss=%.4f') %(epoch, it, loss) + writer.add_scalar('train/loss', loss, it) + if loss_each is not None: + for k, l in loss_each.items(): + if l.item() != 0.: + log_text += (' loss_%s=%.4f') % (k, l.item()) + writer.add_scalar('train/%s' % k, l, it) + + log_text += (' time=%.3f / %.2f') % (runtime['all'].val, runtime['all'].sum) + logger.info(log_text) + + if (it>0)& (it % cfg['train']['visualize_every'] == 0): + for i, batch_vis in 
enumerate(vis_loader): + trainer.save(model, batch_vis, it, i) + if i >= 4: + break + logger.info('Saved mesh and pointcloud') + + # run validation + if it > 0 and (it % cfg['train']['validate_every']) == 0: + eval_dict = trainer.evaluate(val_loader, model) + metric_val = eval_dict[model_selection_metric] + logger.info('Validation metric (%s): %.4f' + % (model_selection_metric, metric_val)) + + for k, v in eval_dict.items(): + writer.add_scalar('val/%s' % k, v, it) + + if -(metric_val - metric_val_best) >= 0: + metric_val_best = metric_val + logger.info('New best model (loss %.4f)' % metric_val_best) + state = {'epoch': epoch, + 'it': it, + 'loss_val_best': metric_val_best} + state['state_dict'] = model.state_dict() + torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt')) + + # save checkpoint + if (epoch > 0) & (it % cfg['train']['checkpoint_every'] == 0): + state = {'epoch': epoch, + 'it': it, + 'loss_val_best': metric_val_best} + pcl = None + state['state_dict'] = model.state_dict() + + torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt')) + + if (it % cfg['train']['backup_every'] == 0): + torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt')) + logger.info("Backup model at iteration %d" % it) + logger.info("Save new model at iteration %d" % it) + + done=time.time() + +if __name__ == '__main__': + main() \ No newline at end of file
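
Usage note: the following is a minimal, hypothetical sketch of how a few of the helpers added in `src/utils.py` above compose. The sphere test field, grid resolution, schedule values, config layout, and output filename are illustrative assumptions, not values taken from the repository's configs.

```python
# Hypothetical usage sketch for a few helpers defined in src/utils.py.
# The sphere test field, resolutions, schedule values and filename are assumptions.
import torch
import torch.nn.functional as F
from src.utils import (GaussianSmoothing, mc_from_psr, scale2onet, export_mesh,
                       get_learning_rate_schedules, update_optimizer,
                       adjust_learning_rate)

# 1) Smooth a signed grid and extract a mesh with marching cubes.
res = 64
x = torch.linspace(-1, 1, res)
X, Y, Z = torch.meshgrid(x, x, x)
# signed field of a sphere with radius 0.5, shape (1, 1, res, res, res)
psr_grid = (torch.sqrt(X**2 + Y**2 + Z**2) - 0.5)[None, None]

smoother = GaussianSmoothing(channels=1, kernel_size=5, sigma=1.0, dim=3)
# GaussianSmoothing uses an unpadded convolution, so pad by (kernel_size - 1) // 2 per side first
psr_smooth = smoother(F.pad(psr_grid, (2, 2, 2, 2, 2, 2), mode='replicate'))

v, f, n = mc_from_psr(psr_smooth, zero_level=0)    # vertices are returned in [0, 1)
export_mesh('sphere_demo.ply', scale2onet(v), f)   # rescale to the ONet range before export

# 2) Drive Adam over a free point cloud with the step learning-rate helpers.
schedule_specs = {  # assumed layout; the real YAML schema may differ
    'pcl': {'initial': 2e-3, 'interval': 700, 'factor': 0.5, 'final': 1e-5},
}
schedules = get_learning_rate_schedules(schedule_specs)
points = torch.nn.Parameter(torch.rand(1, 20000, 3))  # oriented points would also carry normals
optimizer = update_optimizer(points, cfg={}, epoch=0, schedule=schedules)
for epoch in range(5):
    adjust_learning_rate(schedules, optimizer, epoch)
    # ... compute a loss on `points` (e.g. against a target PSR grid) and step the optimizer ...
```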