init commit

This commit is contained in:
pengsongyou 2021-11-08 11:09:50 +01:00
parent c4f63f1510
commit 12757682f1
58 changed files with 6478 additions and 4 deletions

150
.gitignore vendored Normal file
View file

@ -0,0 +1,150 @@
/out
/data
.vscode
.cache
*.pyc
*.pyd
*.pt
*.so
*.o
*.prof
*.swp
*.lib
*.obj
*.exp
.nfs*
*.jpg
*.png
*.ply
*.off
*.npz
*.txt
# *.sh
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

160
README.md
View file

@ -1,5 +1,6 @@
# Shape As Points (SAP)
### [**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (12 min)**](https://youtu.be/TgR0NvYty0A) <br>
![](./media/teaser_wheel.gif)
@ -10,13 +11,164 @@ Shape As Points: A Differentiable Poisson Solver
**NeurIPS 2021 (Oral)**
If you find our code or paper useful, please consider citing
```bibtex
@inproceedings{Peng2021SAP,
    author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas},
    title = {Shape As Points: A Differentiable Poisson Solver},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year = {2021}}
```
## Installation
First, you have to make sure that you have all dependencies in place.
The simplest way to do so is to use [anaconda](https://www.anaconda.com/).
You can create an anaconda environment called `sap` using
```
conda env create -f environment.yaml
conda activate sap
```
Now, you can install [PyTorch3D](https://pytorch3d.org/) 0.6.0 following the [official instructions](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux), for example:
```sh
pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu102_pyt190/download.html
```
And install [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter):
```sh
conda install pytorch-scatter -c pyg
```
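
To check that the compiled extensions match your PyTorch/CUDA build, you can run a quick, optional sanity check; this assumes the environment created above (PyTorch 1.9.0 with CUDA 10.2 and PyTorch3D 0.6.0):
```python
import torch
import pytorch3d
import torch_scatter  # noqa: F401  (the import succeeds only if the wheel matches your PyTorch build)

# report the installed versions and whether CUDA is available
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
print(pytorch3d.__version__)
```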
## Demo - Quick Start
First, run the script to get the demo data:
```bash
bash scripts/download_demo_data.sh
```
### Optimization-based 3D Surface Reconstruction
You can now quickly test our code on the data shown in the teaser. To this end, simply run:
```python
python optim_hierarchy.py configs/optim_based/teaser.yaml
```
This script should create a folder `out/demo_optim` where the output meshes and the optimized oriented point clouds at different grid resolutions are stored.
To visualize the optimization process on the fly, you can set `o3d_show: True` in [`configs/optim_based/teaser.yaml`](https://github.com/autonomousvision/shape_as_points/tree/main/configs/optim_based/teaser.yaml).
### Learning-based 3D Surface Reconstruction
You can also test SAP on another application: reconstructing from unoriented point clouds with either **large noise** or **outliers** using a learned network.
![](./media/results_large_noise.gif)
For the point clouds with large noise as shown above, you can run:
```python
python generate.py configs/learning_based/demo_large_noise.yaml
```
The results can be found in `out/demo_shapenet_large_noise/generation/vis`.
![](./media/results_outliers.gif)
As for the point clouds with outliers, you can run:
```python
python generate.py configs/learning_based/demo_outlier.yaml
```
You can find the reconstructions in `out/demo_shapenet_outlier/generation/vis`.
## Dataset
We use different datasets for the optimization-based and learning-based settings.
### Dataset for Optimization-based Reconstruction
Here we consider the following datasets:
- [Thingi10K](https://arxiv.org/abs/1605.04797) (synthetic)
- [Surface Reconstruction Benchmark (SRB)](https://github.com/fwilliams/deep-geometric-prior) (real scans)
- [MPI Dynamic FAUST](https://dfaust.is.tue.mpg.de/) (real scans)
Please cite the corresponding papers if you use the data.
You can download the processed dataset (~200 MB) by running:
```bash
bash scripts/download_optim_data.sh
```
### Dataset for Learning-based Reconstruction
We train and evaluate on [ShapeNet](https://shapenet.org/).
You can download the processed dataset (~220 GB) by running:
```bash
bash scripts/download_shapenet.sh
```
Afterwards, you should have the dataset in the `data/shapenet_psr` folder.
Alternatively, you can preprocess the dataset yourself. To this end, you can:
* first download the preprocessed dataset (73.4 GB) by running [the script](https://github.com/autonomousvision/occupancy_networks#preprocessed-data) from Occupancy Networks.
* check [`scripts/process_shapenet.py`](https://github.com/autonomousvision/shape_as_points/tree/main/scripts/process_shapenet.py), modify the base path, and run the script (the relevant settings are sketched below)
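
For reference, the settings you will most likely adjust sit at the top of `scripts/process_shapenet.py`; the values below are the defaults in this commit:
```python
data_path = 'data/ShapeNet'    # path to the preprocessed ShapeNet data from Occupancy Networks
base = 'data'                  # output base directory
dataset_name = 'shapenet_psr'  # outputs are written to <base>/<dataset_name>/<class>/<object>
multiprocess = True            # process objects in parallel
njobs = 8                      # number of worker processes
resolution = 128               # PSR grid resolution
```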
## Usage for Optimization-based 3D Reconstruction
For our optimization-based setting, you can consider running with a coarse-to-fine strategy:
```python
python optim_hierarchy.py configs/optim_based/CONFIG.yaml
```
We start from a grid resolution of 32^3, and increase to 64^3, 128^3 and finally 256^3.
Alternatively, you can also run on a single resolution with:
```python
python optim.py configs/optim_based/CONFIG.yaml
```
You might need to modify the `CONFIG.yaml` accordingly.
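Instead of editing `CONFIG.yaml`, you can also override individual options on the command line, which is what `optim_hierarchy.py` does internally when it launches `optim.py`; the values below are only illustrative:
```python
python optim.py configs/optim_based/CONFIG.yaml --model:grid_res 128 --model:psr_sigma 2 --train:out_dir out/my_run
```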
## Usage for Learning-based 3D Reconstruction
### Mesh Generation
To generate meshes using a trained model, use
```python
python generate.py configs/learning_based/CONFIG.yaml
```
where you replace `CONFIG.yaml` with the correct config file.
#### Use a pre-trained model
The easiest way is to use a pre-trained model. You can do this by using one of the config files with the suffix `_pretrained`.
For example, for 3D reconstruction from point clouds with outliers using our model with 7x offsets, you can simply run:
```python
python generate.py configs/learning_based/outlier/ours_7x_pretrained.yaml
```
The script will automatically download the pretrained model and run the generation. You can find the outputs in the `out/.../generation_pretrained` folders.
**Note**: these config files are only for generation, not for training new models. If such a config is used for training, the model will be trained from scratch; during inference, however, our code will still use the pretrained model.
We provide the following pretrained models:
```
noise_small/ours.pt
noise_large/ours.pt
outlier/ours_1x.pt
outlier/ours_3x.pt
outlier/ours_5x.pt
outlier/ours_7x.pt
outlier/ours_3plane.pt
```
### Evaluation
To evaluate a trained model, we provide the script [`eval_meshes.py`](https://github.com/autonomousvision/shape_as_points/blob/main/eval_meshes.py). You can run it using:
```python
python eval_meshes.py configs/learning_based/CONFIG.yaml
```
The script takes the meshes generated in the previous step and evaluates them using a standardized protocol. The output will be written to `.pkl` and `.csv` files in the corresponding generation folder, which can be processed using [pandas](https://pandas.pydata.org/).
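For example, a minimal sketch for loading the per-class summary with pandas (the exact path depends on your `out_dir` and `generation_dir`; the demo path below is an assumption):
```python
import pandas as pd

# eval_meshes.py writes eval_meshes_full.pkl and eval_meshes.csv into the generation folder
df = pd.read_csv('out/demo_shapenet_outlier/generation/eval_meshes.csv', index_col=0)
print(df)
```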
### Training
Finally, to train a new network from scratch, simply run:
```python
python train.py configs/learning_based/CONFIG.yaml
```
For available training options, please take a look at `configs/default.yaml`.
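All configs are merged with `configs/default.yaml` at load time (and can additionally `inherit_from` another config, as the demo configs do). A minimal sketch of how the scripts in this repo read a config, using `load_config` from `src/utils.py`:
```python
from src.utils import load_config

# values in the experiment config override the defaults from configs/default.yaml
cfg = load_config('configs/learning_based/noise_large/ours.yaml', 'configs/default.yaml')
print(cfg['train']['out_dir'], cfg['model']['grid_res'], cfg['train']['lr'])
```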

104
configs/default.yaml Normal file
View file

@ -0,0 +1,104 @@
data:
dataset: Shapes3D
path: data/ShapeNet
class: null
data_type: img
input_type: pointcloud
dim: 3
num_points: 1000
num_gt_points: 1000
num_offset: 1
img_size: null
n_views_input: 20
n_views_per_iter: 2
pointcloud_noise: 0
pointcloud_file: pointcloud.npz
pointcloud_outlier_ratio: 0
fixed_scale: 0
train_split: train
val_split: val
test_split: test
points_file: null
points_iou_file: points.npz
points_unpackbits: true
padding: 0.1
multi_files: null
gt_mesh: null
zero_level: 0
only_single: False
sample_only_floor: False
model:
apply_sigmoid: True
grid_res: 128 # poisson grid resolution
psr_sigma: 0
psr_tanh: False
normal_normalize: False
raster: {}
renderer: {}
encoder: null
predict_normal: True
predict_offset: True
s_offset: 0.001
local_coord: True
encoder_kwargs: {}
unet3d: False
unet3d_kwargs: {}
multi_gpu: false
rotate_matrix: false
c_dim: 512
sphere_radius: 0.2
train:
lr: 1e-3
lr_pcl: 2e-2
input_mesh: ''
out_dir: out/default
subsample_vertex: False
batch_size: 1
n_grow_points: 0
resample_every: 0
l_weight: {}
w_reg_point: 0
w_psr: 0
w_raw: 0 # train with raw point cloud
gauss_weight: 0
n_sup_point: 0
w_normals: 0
total_epochs: 1000
print_every: 20
visualize_every: 1000
save_every: 1000
vis_vert_color: True
o3d_show: False
o3d_vis_pcl: True
o3d_window_size: 540
vis_rendering: False
vis_psr: False
save_video: False
exp_mesh: True
exp_pcl: True
checkpoint_every: 1000
validate_every: 2000000
backup_every: 100000
timestamp: False # add timestamp to out_dir name
model_selection_metric: loss
model_selection_mode: minimize
n_workers: 0
n_workers_val: 0
test:
threshold: 0.5
eval_mesh: true
eval_pointcloud: false
model_file: model_best.pt
generation:
batch_size: 100000
exp_gt: False
exp_oracle: false
exp_input: False
vis_n_outputs: 10
generate_mesh: true
generate_pointcloud: true
generation_dir: generation
copy_input: true
use_sampling: false
psr_resolution: 0
psr_sigma: 0

View file

@ -0,0 +1,9 @@
inherit_from: configs/learning_based/noise_large/ours.yaml
data:
class: ['']
path: data/demo/shapenet_chair
train:
out_dir: out/demo_shapenet_large_noise
test:
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt

View file

@ -0,0 +1,9 @@
inherit_from: configs/learning_based/outlier/ours_7x.yaml
data:
class: ['']
path: data/demo/shapenet_lamp
train:
out_dir: out/demo_shapenet_outlier
test:
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt

View file

@ -0,0 +1,54 @@
data:
class: null
data_type: psr_full
input_type: pointcloud
path: data/shapenet_psr
num_gt_points: 10000
num_offset: 7
pointcloud_n: 3000
pointcloud_noise: 0.025
model:
grid_res: 128 # poisson grid resolution
psr_sigma: 2
psr_tanh: True
normal_normalize: False
predict_normal: True
predict_offset: True
c_dim: 32
s_offset: 0.001
encoder: local_pool_pointnet
encoder_kwargs:
hidden_dim: 32
plane_type: 'grid'
grid_resolution: 32
unet3d: True
unet3d_kwargs:
num_levels: 3
f_maps: 32
in_channels: 32
out_channels: 32
decoder: simple_local
decoder_kwargs:
sample_mode: bilinear # bilinear / nearest
hidden_size: 32
train:
batch_size: 32
lr: 5e-4
out_dir: out/shapenet/noise_025_ours
w_psr: 1
model_selection_metric: psr_l2
print_every: 100
checkpoint_every: 200
validate_every: 5000
backup_every: 10000
total_epochs: 400000
visualize_every: 5000
exp_pcl: True
exp_mesh: True
n_workers: 8
n_workers_val: 0
generation:
exp_gt: False
exp_input: True
psr_resolution: 128
psr_sigma: 2

View file

@ -0,0 +1,5 @@
inherit_from: configs/learning_based/noise_large/ours.yaml
generation:
generation_dir: generation_pretrained
test:
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt

View file

@ -0,0 +1,54 @@
data:
class: null
data_type: psr_full
input_type: pointcloud
path: data/shapenet_psr
num_gt_points: 10000
num_offset: 7
pointcloud_n: 3000
pointcloud_noise: 0.005
model:
grid_res: 128 # poisson grid resolution
psr_sigma: 2
psr_tanh: True
normal_normalize: False
predict_normal: True
predict_offset: True
c_dim: 32
s_offset: 0.001
encoder: local_pool_pointnet
encoder_kwargs:
hidden_dim: 32
plane_type: 'grid'
grid_resolution: 32
unet3d: True
unet3d_kwargs:
num_levels: 3
f_maps: 32
in_channels: 32
out_channels: 32
decoder: simple_local
decoder_kwargs:
sample_mode: bilinear # bilinear / nearest
hidden_size: 32
train:
batch_size: 32
lr: 5e-4
out_dir: out/shapenet/noise_005_ours
w_psr: 1
model_selection_metric: psr_l2
print_every: 100
checkpoint_every: 200
validate_every: 5000
backup_every: 10000
total_epochs: 400000
visualize_every: 5000
exp_pcl: True
exp_mesh: True
n_workers: 8
n_workers_val: 0
generation:
exp_gt: False
exp_input: True
psr_resolution: 128
psr_sigma: 2

View file

@ -0,0 +1,5 @@
inherit_from: configs/learning_based/noise_small/ours.yaml
generation:
generation_dir: generation_pretrained
test:
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_005.pt

View file

@ -0,0 +1,55 @@
data:
class: null
data_type: psr_full
input_type: pointcloud
path: data/shapenet_psr
num_gt_points: 10000
num_offset: 1
pointcloud_n: 3000
pointcloud_noise: 0.005
pointcloud_outlier_ratio: 0.5
model:
grid_res: 128 # poisson grid resolution
psr_sigma: 2
psr_tanh: True
normal_normalize: False
predict_normal: True
predict_offset: True
c_dim: 32
s_offset: 0.001
encoder: local_pool_pointnet
encoder_kwargs:
hidden_dim: 32
plane_type: 'grid'
grid_resolution: 32
unet3d: True
unet3d_kwargs:
num_levels: 3
f_maps: 32
in_channels: 32
out_channels: 32
decoder: simple_local
decoder_kwargs:
sample_mode: bilinear # bilinear / nearest
hidden_size: 32
train:
batch_size: 32
lr: 5e-4
out_dir: out/shapenet/outlier_ours_1x
w_psr: 1
model_selection_metric: psr_l2
print_every: 100
checkpoint_every: 200
validate_every: 5000
backup_every: 10000
total_epochs: 400000
visualize_every: 5000
exp_pcl: True
exp_mesh: True
n_workers: 8
n_workers_val: 0
generation:
exp_gt: False
exp_input: True
psr_resolution: 128
psr_sigma: 2

View file

@ -0,0 +1,5 @@
inherit_from: configs/learning_based/outlier/ours_3x/ours.yaml
generation:
generation_dir: generation_pretrained
test:
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_3x.pt

View file

@ -0,0 +1,54 @@
data:
class: null
data_type: psr_full
input_type: pointcloud
path: data/shapenet_psr
num_gt_points: 10000
num_offset: 5
pointcloud_n: 3000
pointcloud_noise: 0.005
pointcloud_outlier_ratio: 0.5
model:
grid_res: 128 # poisson grid resolution
psr_sigma: 2
psr_tanh: True
normal_normalize: False
predict_normal: True
predict_offset: True
c_dim: 32
s_offset: 0.001
encoder: local_pool_pointnet
encoder_kwargs:
hidden_dim: 32
plane_type: ['xz', 'xy', 'yz']
plane_resolution: 64
unet: True
unet_kwargs:
depth: 4
merge_mode: concat
start_filts: 32
decoder: simple_local
decoder_kwargs:
sample_mode: bilinear # bilinear / nearest
hidden_size: 32
train:
batch_size: 32
lr: 5e-4
out_dir: out/shapenet/outlier_ours_3plane
w_psr: 1
model_selection_metric: psr_l2
print_every: 100
checkpoint_every: 200
validate_every: 5000
backup_every: 10000
total_epochs: 400000
visualize_every: 5000
exp_pcl: True
exp_mesh: True
n_workers: 8
n_workers_val: 0
generation:
exp_gt: False
exp_input: True
psr_resolution: 128
psr_sigma: 2

View file

@ -0,0 +1,55 @@
data:
class: null
data_type: psr_full
input_type: pointcloud
path: data/shapenet_psr
num_gt_points: 10000
num_offset: 3
pointcloud_n: 3000
pointcloud_noise: 0.005
pointcloud_outlier_ratio: 0.5
model:
grid_res: 128 # poisson grid resolution
psr_sigma: 2
psr_tanh: True
normal_normalize: False
predict_normal: True
predict_offset: True
c_dim: 32
s_offset: 0.001
encoder: local_pool_pointnet
encoder_kwargs:
hidden_dim: 32
plane_type: 'grid'
grid_resolution: 32
unet3d: True
unet3d_kwargs:
num_levels: 3
f_maps: 32
in_channels: 32
out_channels: 32
decoder: simple_local
decoder_kwargs:
sample_mode: bilinear # bilinear / nearest
hidden_size: 32
train:
batch_size: 32
lr: 5e-4
out_dir: out/shapenet/outlier_ours_3x
w_psr: 1
model_selection_metric: psr_l2
print_every: 100
checkpoint_every: 200
validate_every: 5000
backup_every: 10000
total_epochs: 400000
visualize_every: 5000
exp_pcl: True
exp_mesh: True
n_workers: 8
n_workers_val: 0
generation:
exp_gt: False
exp_input: True
psr_resolution: 128
psr_sigma: 2

View file

@ -0,0 +1,5 @@
inherit_from: configs/learning_based/outlier/ours_1x/ours.yaml
generation:
generation_dir: generation_pretrained
test:
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_1x.pt

View file

@ -0,0 +1,55 @@
data:
class: null
data_type: psr_full
input_type: pointcloud
path: data/shapenet_psr
num_gt_points: 10000
num_offset: 5
pointcloud_n: 3000
pointcloud_noise: 0.005
pointcloud_outlier_ratio: 0.5
model:
grid_res: 128 # poisson grid resolution
psr_sigma: 2
psr_tanh: True
normal_normalize: False
predict_normal: True
predict_offset: True
c_dim: 32
s_offset: 0.001
encoder: local_pool_pointnet
encoder_kwargs:
hidden_dim: 32
plane_type: 'grid'
grid_resolution: 32
unet3d: True
unet3d_kwargs:
num_levels: 3
f_maps: 32
in_channels: 32
out_channels: 32
decoder: simple_local
decoder_kwargs:
sample_mode: bilinear # bilinear / nearest
hidden_size: 32
train:
batch_size: 32
lr: 5e-4
out_dir: out/shapenet/outlier_ours_5x
w_psr: 1
model_selection_metric: psr_l2
print_every: 100
checkpoint_every: 200
validate_every: 5000
backup_every: 10000
total_epochs: 400000
visualize_every: 5000
exp_pcl: True
exp_mesh: True
n_workers: 8
n_workers_val: 0
generation:
exp_gt: False
exp_input: True
psr_resolution: 128
psr_sigma: 2

View file

@ -0,0 +1,5 @@
inherit_from: configs/learning_based/outlier/ours_5x/ours.yaml
generation:
generation_dir: generation_pretrained
test:
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_5x.pt

View file

@ -0,0 +1,55 @@
data:
class: null
data_type: psr_full
input_type: pointcloud
path: data/shapenet_psr
num_gt_points: 10000
num_offset: 7
pointcloud_n: 3000
pointcloud_noise: 0.005
pointcloud_outlier_ratio: 0.5
model:
grid_res: 128 # poisson grid resolution
psr_sigma: 2
psr_tanh: True
normal_normalize: False
predict_normal: True
predict_offset: True
c_dim: 32
s_offset: 0.001
encoder: local_pool_pointnet
encoder_kwargs:
hidden_dim: 32
plane_type: 'grid'
grid_resolution: 32
unet3d: True
unet3d_kwargs:
num_levels: 3
f_maps: 32
in_channels: 32
out_channels: 32
decoder: simple_local
decoder_kwargs:
sample_mode: bilinear # bilinear / nearest
hidden_size: 32
train:
batch_size: 32
lr: 5e-4
out_dir: out/shapenet/outlier_ours_7x
w_psr: 1
model_selection_metric: psr_l2
print_every: 100
checkpoint_every: 200
validate_every: 5000
backup_every: 10000
total_epochs: 400000
visualize_every: 5000
exp_pcl: True
exp_mesh: True
n_workers: 8
n_workers_val: 0
generation:
exp_gt: False
exp_input: True
psr_resolution: 128
psr_sigma: 2

View file

@ -0,0 +1,5 @@
inherit_from: configs/learning_based/outlier/ours_7x/ours.yaml
generation:
generation_dir: generation_pretrained
test:
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt

View file

@ -0,0 +1,37 @@
data:
class: 'only_pcl'
data_type: point
data_path: 'data/dfaust/*.ply'
object_id: 0
num_points: 20000
model:
sphere_radius: 0.2
grid_res: 256 # poisson grid resolution
psr_sigma: 2
train:
schedule:
pcl:
initial: 1e-2
interval: 700
factor: 0.5
final: 1e-3
out_dir: out/dfaust
w_chamfer: 1
n_sup_point: 20000
batch_size: 1
n_grow_points: 2000
resample_every: 200
subsample_vertex: False
total_epochs: 1600
print_every: 10
visualize_every: 2
checkpoint_every: 500
save_every: 100
exp_pcl: True
exp_mesh: True
o3d_show: False
o3d_vis_pcl: True
o3d_window_size: 540
vis_rendering: False
vis_vert_color: False
n_workers: 0

View file

@ -0,0 +1,38 @@
data:
class: 'only_pcl'
data_type: point
data_path: 'data/deep_geometric_prior_data/*.ply'
object_id: 0
num_points: 20000
model:
sphere_radius: 0.2
grid_res: 256 # poisson grid resolution
psr_sigma: 2
train:
schedule:
pcl:
initial: 1e-2
interval: 700
factor: 0.5
final: 1e-3
out_dir: out/dgp
w_reg_point: 0
w_chamfer: 1
n_sup_point: 20000
batch_size: 1
n_grow_points: 2000
resample_every: 200
subsample_vertex: False
total_epochs: 1600
print_every: 10
visualize_every: 2
checkpoint_every: 500
save_every: 100
exp_pcl: True
exp_mesh: True
o3d_show: False
o3d_vis_pcl: True
o3d_window_size: 540
vis_rendering: False
vis_vert_color: False
n_workers: 0

View file

@ -0,0 +1,37 @@
data:
class: 'only_pcl'
data_type: point
data_path: 'data/demo/wheel.obj'
object_id: 0
num_points: 20000
model:
sphere_radius: 0.2
grid_res: 128 # poisson grid resolution
psr_sigma: 2
train:
schedule:
pcl:
initial: 1e-2
interval: 700
factor: 0.5
final: 1e-3
out_dir: out/demo_optim
w_chamfer: 1
n_sup_point: 20000
batch_size: 1
n_grow_points: 2000
resample_every: 200
subsample_vertex: False
total_epochs: 1600
print_every: 10
visualize_every: 2
checkpoint_every: 500
save_every: 100
exp_pcl: True
exp_mesh: True
o3d_show: False
o3d_vis_pcl: True
o3d_window_size: 540
vis_rendering: False
vis_vert_color: False
n_workers: 0

View file

@ -0,0 +1,39 @@
data:
class: 'only_pcl'
data_type: point
data_path: 'data/thingi/*.ply'
object_id: 0
num_points: 20000
model:
sphere_radius: 0.2
grid_res: 128 # poisson grid resolution
psr_sigma: 2
train:
# lr_pcl: 2e-2
schedule:
pcl:
initial: 1e-2
interval: 700
factor: 0.5
final: 1e-3
out_dir: out/thingi
w_reg_point: 0
w_chamfer: 1
n_sup_point: 20000
batch_size: 1
n_grow_points: 2000
resample_every: 200
subsample_vertex: False
total_epochs: 1600
print_every: 10
visualize_every: 2
checkpoint_every: 500
save_every: 100
exp_pcl: True
exp_mesh: True
o3d_show: False
o3d_vis_pcl: True
o3d_window_size: 540
vis_rendering: False
vis_vert_color: False
n_workers: 0

View file

@ -0,0 +1,39 @@
data:
class: 'only_pcl'
data_type: point
data_path: 'data/thingi_noisy/*.ply'
object_id: 0
num_points: 20000
model:
sphere_radius: 0.2
grid_res: 128 # poisson grid resolution
psr_sigma: 2
train:
# lr_pcl: 2e-2
schedule:
pcl:
initial: 1e-2
interval: 700
factor: 0.5
final: 1e-3
out_dir: out/thingi_noisy
w_reg_point: 0
w_chamfer: 1
n_sup_point: 20000
batch_size: 1
n_grow_points: 2000
resample_every: 200
subsample_vertex: False
total_epochs: 1600
print_every: 10
visualize_every: 2
checkpoint_every: 500
save_every: 100
exp_pcl: True
exp_mesh: True
o3d_show: False
o3d_vis_pcl: True
o3d_window_size: 540
vis_rendering: False
vis_vert_color: False
n_workers: 0

29
environment.yaml Normal file
View file

@ -0,0 +1,29 @@
name: sap
channels:
- anaconda
- conda-forge
- pytorch
- defaults
dependencies:
- python=3.8
- pytorch=1.9.0
- torchvision=0.10.0
- cudatoolkit=10.2
- numpy=1.18.1
- matplotlib=3.4.3
- pyyaml=5.3.1
- scipy=1.4.1
- tqdm=4.54.0
- trimesh=3.8.14
- igl=2.2.1
- tensorboard=2.6.0
- pip
- pip:
- plyfile==0.7
- open3d>=0.11.1
- scikit-image>=0.18.0
- python-mnist==0.7
- opencv-python>=4.4
- av==8.0.3
- pykdtree==1.3.4
- ipdb==0.13.7

155
eval_meshes.py Normal file
View file

@ -0,0 +1,155 @@
import torch
import trimesh
from torch.utils.data import Dataset, DataLoader
import numpy as np; np.set_printoptions(precision=4)
import shutil, argparse, time, os
import pandas as pd
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
from src.training import Trainer
from src.model import Encode2Points
from src.data import PointCloudField, IndexField, Shapes3dDataset
from src.utils import load_config, load_pointcloud
from src.eval import MeshEvaluator
from tqdm import tqdm
from pdb import set_trace as st
def main():
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
args = parser.parse_args()
cfg = load_config(args.config, 'configs/default.yaml')
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
data_type = cfg['data']['data_type']
# Shorthands
out_dir = cfg['train']['out_dir']
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
if cfg['generation'].get('iter', 0)!=0:
generation_dir += '_%04d'%cfg['generation']['iter']
elif args.iter is not None:
generation_dir += '_%04d'%args.iter
print('Evaluate meshes under %s'%generation_dir)
out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl')
out_file_class = os.path.join(generation_dir, 'eval_meshes.csv')
# PYTORCH VERSION > 1.0.0
assert(float(torch.__version__.split('.')[-3]) > 0)
pointcloud_field = PointCloudField(cfg['data']['pointcloud_file'])
fields = {
'pointcloud': pointcloud_field,
'idx': IndexField(),
}
print('Test split: ', cfg['data']['test_split'])
dataset_folder = cfg['data']['path']
dataset = Shapes3dDataset(
dataset_folder, fields,
cfg['data']['test_split'],
categories=cfg['data']['class'], cfg=cfg)
# Loader
test_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, num_workers=0, shuffle=False)
# Evaluator
evaluator = MeshEvaluator(n_points=100000)
eval_dicts = []
print('Evaluating meshes...')
for it, data in enumerate(tqdm(test_loader)):
if data is None:
print('Invalid data.')
continue
mesh_dir = os.path.join(generation_dir, 'meshes')
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
# Get index etc.
idx = data['idx'].item()
try:
model_dict = dataset.get_model_dict(idx)
except AttributeError:
model_dict = {'model': str(idx), 'category': 'n/a'}
modelname = model_dict['model']
category_id = model_dict['category']
try:
category_name = dataset.metadata[category_id].get('name', 'n/a')
except AttributeError:
category_name = 'n/a'
if category_id != 'n/a':
mesh_dir = os.path.join(mesh_dir, category_id)
pointcloud_dir = os.path.join(pointcloud_dir, category_id)
# Evaluate
pointcloud_tgt = data['pointcloud'].squeeze(0).numpy()
normals_tgt = data['pointcloud.normals'].squeeze(0).numpy()
eval_dict = {
'idx': idx,
'class id': category_id,
'class name': category_name,
'modelname':modelname,
}
eval_dicts.append(eval_dict)
# Evaluate mesh
if cfg['test']['eval_mesh']:
mesh_file = os.path.join(mesh_dir, '%s.off' % modelname)
if os.path.exists(mesh_file):
mesh = trimesh.load(mesh_file, process=False)
eval_dict_mesh = evaluator.eval_mesh(
mesh, pointcloud_tgt, normals_tgt)
for k, v in eval_dict_mesh.items():
eval_dict[k + ' (mesh)'] = v
else:
print('Warning: mesh does not exist: %s' % mesh_file)
# Evaluate point cloud
if cfg['test']['eval_pointcloud']:
pointcloud_file = os.path.join(
pointcloud_dir, '%s.ply' % modelname)
if os.path.exists(pointcloud_file):
pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
eval_dict_pcl = evaluator.eval_pointcloud(
pointcloud, pointcloud_tgt)
for k, v in eval_dict_pcl.items():
eval_dict[k + ' (pcl)'] = v
else:
print('Warning: pointcloud does not exist: %s'
% pointcloud_file)
# Create pandas dataframe and save
eval_df = pd.DataFrame(eval_dicts)
eval_df.set_index(['idx'], inplace=True)
eval_df.to_pickle(out_file)
# Create CSV file with main statistics
eval_df_class = eval_df.groupby(by=['class name']).mean()
eval_df_class.loc['mean'] = eval_df_class.mean()
eval_df_class.to_csv(out_file_class)
# Print results
print(eval_df_class)
if __name__ == '__main__':
main()

217
generate.py Normal file
View file

@ -0,0 +1,217 @@
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np; np.set_printoptions(precision=4)
import shutil, argparse, time, os
import pandas as pd
from collections import defaultdict
from src import config
from src.utils import mc_from_psr, export_mesh, export_pointcloud
from src.dpsr import DPSR
from src.training import Trainer
from src.model import Encode2Points
from src.utils import load_config, load_model_manual, scale2onet, is_url, load_url
from tqdm import tqdm
from pdb import set_trace as st
def main():
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
args = parser.parse_args()
cfg = load_config(args.config, 'configs/default.yaml')
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
data_type = cfg['data']['data_type']
input_type = cfg['data']['input_type']
vis_n_outputs = cfg['generation']['vis_n_outputs']
if vis_n_outputs is None:
vis_n_outputs = -1
# Shorthands
out_dir = cfg['train']['out_dir']
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')
# PYTORCH VERSION > 1.0.0
assert(float(torch.__version__.split('.')[-3]) > 0)
dataset = config.get_dataset('test', cfg, return_idx=True)
test_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, num_workers=0, shuffle=False)
model = Encode2Points(cfg).to(device)
# load model
try:
if is_url(cfg['test']['model_file']):
state_dict = load_url(cfg['test']['model_file'])
elif cfg['generation'].get('iter', 0)!=0:
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% cfg['generation']['iter']))
generation_dir += '_%04d'%cfg['generation']['iter']
elif args.iter is not None:
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% args.iter))
else:
state_dict = torch.load(os.path.join(out_dir, 'model_best.pt'))
load_model_manual(state_dict['state_dict'], model)
except:
print('Model loading error. Exiting.')
exit()
# Generator
generator = config.get_generator(model, cfg, device=device)
# Determine what to generate
generate_mesh = cfg['generation']['generate_mesh']
generate_pointcloud = cfg['generation']['generate_pointcloud']
# Statistics
time_dicts = []
# Generate
model.eval()
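    # DPSR solver at the generation resolution; used below to compute oracle PSR grids from GT points and normals when exp_oracle is enabled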
dpsr = DPSR(res=(cfg['generation']['psr_resolution'],
cfg['generation']['psr_resolution'],
cfg['generation']['psr_resolution']),
sig= cfg['generation']['psr_sigma']).to(device)
# Count how many models already created
model_counter = defaultdict(int)
print('Generating...')
for it, data in enumerate(tqdm(test_loader)):
# Output folders
mesh_dir = os.path.join(generation_dir, 'meshes')
in_dir = os.path.join(generation_dir, 'input')
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
generation_vis_dir = os.path.join(generation_dir, 'vis', )
# Get index etc.
idx = data['idx'].item()
try:
model_dict = dataset.get_model_dict(idx)
except AttributeError:
model_dict = {'model': str(idx), 'category': 'n/a'}
modelname = model_dict['model']
category_id = model_dict['category']
try:
category_name = dataset.metadata[category_id].get('name', 'n/a')
except AttributeError:
category_name = 'n/a'
if category_id != 'n/a':
mesh_dir = os.path.join(mesh_dir, str(category_id))
pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
in_dir = os.path.join(in_dir, str(category_id))
folder_name = str(category_id)
if category_name != 'n/a':
folder_name = str(folder_name) + '_' + category_name.split(',')[0]
generation_vis_dir = os.path.join(generation_vis_dir, folder_name)
# Create directories if necessary
if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir):
os.makedirs(generation_vis_dir)
if generate_mesh and not os.path.exists(mesh_dir):
os.makedirs(mesh_dir)
if generate_pointcloud and not os.path.exists(pointcloud_dir):
os.makedirs(pointcloud_dir)
if not os.path.exists(in_dir):
os.makedirs(in_dir)
# Timing dict
time_dict = {
'idx': idx,
'class id': category_id,
'class name': category_name,
'modelname':modelname,
}
time_dicts.append(time_dict)
# Generate outputs
out_file_dict = {}
if generate_mesh:
#! deploy the generator to a separate class
out = generator.generate_mesh(data)
v, f, points, normals, stats_dict = out
time_dict.update(stats_dict)
# Write output
mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname)
export_mesh(mesh_out_file, scale2onet(v), f)
out_file_dict['mesh'] = mesh_out_file
if generate_pointcloud:
pointcloud_out_file = os.path.join(
pointcloud_dir, '%s.ply' % modelname)
export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
out_file_dict['pointcloud'] = pointcloud_out_file
if cfg['generation']['copy_input']:
inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
p = data.get('inputs').to(device)
export_pointcloud(inputs_path, scale2onet(p))
out_file_dict['in'] = inputs_path
# Copy to visualization directory for first vis_n_output samples
c_it = model_counter[category_id]
if c_it < vis_n_outputs:
# Save output files
img_name = '%02d.off' % c_it
for k, filepath in out_file_dict.items():
ext = os.path.splitext(filepath)[1]
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
% (c_it, k, ext))
shutil.copyfile(filepath, out_file)
# Also generate oracle meshes
if cfg['generation']['exp_oracle']:
points_gt = data.get('gt_points').to(device)
normals_gt = data.get('gt_points.normals').to(device)
psr_gt = dpsr(points_gt, normals_gt)
v, f, _ = mc_from_psr(psr_gt,
zero_level=cfg['data']['zero_level'])
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
% (c_it, 'mesh_oracle', '.off'))
export_mesh(out_file, scale2onet(v), f)
model_counter[category_id] += 1
# Create pandas dataframe and save
time_df = pd.DataFrame(time_dicts)
time_df.set_index(['idx'], inplace=True)
time_df.to_pickle(out_time_file)
# Create pickle files with main statistics
time_df_class = time_df.groupby(by=['class name']).mean()
time_df_class.loc['mean'] = time_df_class.mean()
time_df_class.to_pickle(out_time_file_class)
# Print results
print('Timings [s]:')
print(time_df_class)
if __name__ == '__main__':
main()

Binary file not shown.


BIN
media/results_outliers.gif Normal file

Binary file not shown.


315
optim.py Normal file
View file

@ -0,0 +1,315 @@
import torch
import trimesh
import shutil, argparse, time, os, glob
import numpy as np; np.set_printoptions(precision=4)
import open3d as o3d
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from torchvision.io import write_video
from src.optimization import Trainer
from src.utils import load_config, update_config, initialize_logger, \
get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\
update_optimizer, export_pointcloud
from skimage import measure
from plyfile import PlyData
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.structures import Meshes
def main():
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1457, metavar='S',
help='random seed')
args, unknown = parser.parse_known_args()
cfg = load_config(args.config, 'configs/default.yaml')
cfg = update_config(cfg, unknown)
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
data_type = cfg['data']['data_type']
data_class = cfg['data']['class']
print(cfg['train']['out_dir'])
# PYTORCH VERSION > 1.0.0
assert(float(torch.__version__.split('.')[-3]) > 0)
# boiler-plate
if cfg['train']['timestamp']:
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
logger = initialize_logger(cfg)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
shutil.copyfile(args.config,
os.path.join(cfg['train']['out_dir'], 'config.yaml'))
# tensorboardX writer
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
if not os.path.exists(tblogdir):
os.makedirs(tblogdir)
writer = SummaryWriter(log_dir=tblogdir)
# initialize o3d visualizer
vis = None
if cfg['train']['o3d_show']:
vis = o3d.visualization.Visualizer()
vis.create_window(width=cfg['train']['o3d_window_size'],
height=cfg['train']['o3d_window_size'])
# initialize dataset
if data_type == 'point':
if cfg['data']['object_id'] != -1:
data_paths = sorted(glob.glob(cfg['data']['data_path']))
data_path = data_paths[cfg['data']['object_id']]
print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths)))
else:
data_path = cfg['data']['data_path']
print('Data loaded')
ext = data_path.split('.')[-1]
if ext == 'obj': # have GT mesh
mesh = load_objs_as_meshes([data_path], device=device)
# scale the mesh into unit cube
verts = mesh.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
mesh.offset_verts_(-center.expand(N, 3))
scale = max((verts - center).abs().max(0)[0])
mesh.scale_verts_((1.0 / float(scale)))
# important for our DPSR to have the range in [0, 1), not reaching 1
mesh.scale_verts_(0.9)
target_pts, target_normals = sample_points_from_meshes(mesh,
num_samples=200000, return_normals=True)
elif ext == 'ply': # only have the point cloud
plydata = PlyData.read(data_path)
vertices = np.stack([plydata['vertex']['x'],
plydata['vertex']['y'],
plydata['vertex']['z']], axis=1)
normals = np.stack([plydata['vertex']['nx'],
plydata['vertex']['ny'],
plydata['vertex']['nz']], axis=1)
N = vertices.shape[0]
center = vertices.mean(0)
scale = np.max(np.max(np.abs(vertices - center), axis=0))
vertices -= center
vertices /= scale
vertices *= 0.9
target_pts = torch.tensor(vertices, device=device)[None].float()
target_normals = torch.tensor(normals, device=device)[None].float()
mesh = None # no GT mesh
if not torch.is_tensor(center):
center = torch.from_numpy(center)
if not torch.is_tensor(scale):
scale = torch.from_numpy(np.array([scale]))
data = {'target_points': target_pts,
'target_normals': target_normals, # normals are never used
'gt_mesh': mesh}
else:
raise NotImplementedError
# save the input point cloud
if 'target_points' in data.keys():
outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply')
if 'target_normals' in data.keys():
export_pointcloud(outdir_pcl, data['target_points'], data['target_normals'])
else:
export_pointcloud(outdir_pcl, data['target_points'])
# save oracle PSR mesh (mesh from our PSR using GT point+normals)
if data.get('gt_mesh') is not None:
gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0)
pts_gt, norms_gt = sample_points_from_meshes(data['gt_mesh'],
num_samples=500000, return_normals=True)
pts_gt = (pts_gt + 1) / 2
from src.dpsr import DPSR
dpsr_tmp = DPSR(res=(cfg['model']['grid_res'],
cfg['model']['grid_res'],
cfg['model']['grid_res']),
sig=cfg['model']['psr_sigma']).to(device)
target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
target = torch.tanh(target)
s = target.shape[-1] # size of psr_grid
psr_grid_numpy = target.squeeze().detach().cpu().numpy()
verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
verts = verts / s * 2. - 1 # [-1, 1]
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(verts)
mesh.triangles = o3d.utility.Vector3iVector(faces)
outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply')
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
# initialize the source point cloud given an input mesh
if 'input_mesh' in cfg['train'].keys() and \
os.path.isfile(cfg['train']['input_mesh']):
if cfg['train']['input_mesh'].split('/')[-2] == 'mesh':
mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh'])
verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
mesh = Meshes(verts=verts, faces=faces)
points, normals = sample_points_from_meshes(mesh,
num_samples=cfg['data']['num_points'], return_normals=True)
# mesh is saved in the original scale of the gt
points -= center.float().to(device)
points /= scale.float().to(device)
points *= 0.9
# make sure the points are within the range of [0, 1)
points = points / 2. + 0.5
else:
# directly initialize from a point cloud
pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh'])
points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
points -= center.float().to(device)
points /= scale.float().to(device)
points *= 0.9
points = points / 2. + 0.5
else: #! initialize our source point cloud from a sphere
sphere_radius = cfg['model']['sphere_radius']
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius,
count=[256,256])
points, idx = sphere_mesh.sample(cfg['data']['num_points'],
return_index=True)
points += 0.5 # make sure the points are within the range of [0, 1)
normals = sphere_mesh.face_normals[idx]
points = torch.from_numpy(points).unsqueeze(0).to(device)
normals = torch.from_numpy(normals).unsqueeze(0).to(device)
points = torch.log(points/(1-points)) # inverse sigmoid
inputs = torch.cat([points, normals], axis=-1).float()
inputs.requires_grad = True
model = None # no network
# initialize optimizer
cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl']
print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial'])
if 'schedule' in cfg['train']:
lr_schedules = get_learning_rate_schedules(cfg['train']['schedule'])
else:
lr_schedules = None
optimizer = update_optimizer(inputs, cfg,
epoch=0, model=model, schedule=lr_schedules)
try:
# load model
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
if ('pcl' in state_dict.keys()) & (state_dict['pcl'] is not None):
inputs = state_dict['pcl'].to(device)
inputs.requires_grad = True
optimizer = update_optimizer(inputs, cfg,
epoch=state_dict.get('epoch'), schedule=lr_schedules)
out = "Load model from epoch %d" % state_dict.get('epoch', 0)
print(out)
logger.info(out)
except:
state_dict = dict()
start_epoch = state_dict.get('epoch', -1)
trainer = Trainer(cfg, optimizer, device=device)
runtime = {}
runtime['all'] = AverageMeter()
# training loop
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
# schedule the learning rate
if (epoch>0) & (lr_schedules is not None):
if (epoch % lr_schedules[0].interval == 0):
adjust_learning_rate(lr_schedules, optimizer, epoch)
if len(lr_schedules) >1:
print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch,
lr_schedules[0].get_learning_rate(epoch),
lr_schedules[1].get_learning_rate(epoch)))
else:
print('[epoch {}] adjust pcl_lr to: {}'.format(epoch,
lr_schedules[0].get_learning_rate(epoch)))
start = time.time()
loss, loss_each = trainer.train_step(data, inputs, model, epoch)
runtime['all'].update(time.time() - start)
if epoch % cfg['train']['print_every'] == 0:
log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss)
if loss_each is not None:
for k, l in loss_each.items():
if l.item() != 0.:
log_text += (' loss_%s=%.5f') % (k, l.item())
log_text += (' time=%.3f / %.3f') % (runtime['all'].val,
runtime['all'].sum)
logger.info(log_text)
print(log_text)
# visualize point clouds and meshes
if (epoch % cfg['train']['visualize_every'] == 0) & (vis is not None):
trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)
# save outputs
if epoch % cfg['train']['save_every'] == 0:
trainer.save_mesh_pointclouds(inputs, epoch,
center.cpu().numpy(),
scale.cpu().numpy()*(1/0.9))
# save checkpoints
if (epoch > 0) & (epoch % cfg['train']['checkpoint_every'] == 0):
state = {'epoch': epoch}
pcl = None
if isinstance(inputs, torch.Tensor):
state['pcl'] = inputs.detach().cpu()
torch.save(state, os.path.join(cfg['train']['dir_model'],
'%04d' % epoch + '.pt'))
print("Save new model at epoch %d" % epoch)
logger.info("Save new model at epoch %d" % epoch)
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
# resample and gradually add new points to the source pcl
if (epoch > 0) & \
(cfg['train']['resample_every']!=0) & \
(epoch % cfg['train']['resample_every'] == 0) & \
(epoch < cfg['train']['total_epochs']):
inputs = trainer.point_resampling(inputs)
optimizer = update_optimizer(inputs, cfg,
epoch=epoch, model=model, schedule=lr_schedules)
trainer = Trainer(cfg, optimizer, device=device)
# visualize the Open3D outputs
if cfg['train']['o3d_show']:
out_video_dir = os.path.join(cfg['train']['out_dir'],
'vis/o3d/video.mp4')
if os.path.isfile(out_video_dir):
os.system('rm {}'.format(out_video_dir))
os.system('ffmpeg -framerate 30 \
-start_number 0 \
-i {}/vis/o3d/%04d.jpg \
-pix_fmt yuv420p \
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
out_video_dir = os.path.join(cfg['train']['out_dir'],
'vis/o3d/video_pcd.mp4')
if os.path.isfile(out_video_dir):
os.system('rm {}'.format(out_video_dir))
os.system('ffmpeg -framerate 30 \
-start_number 0 \
-i {}/vis/o3d/%04d_pcd.jpg \
-pix_fmt yuv420p \
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
print('Video saved.')
if __name__ == '__main__':
main()

69
optim_hierarchy.py Normal file
View file

@ -0,0 +1,69 @@
import sys, os
import argparse
from src.utils import load_config
import subprocess
os.environ['MKL_THREADING_LAYER'] = 'GNU'
def main():
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--start_res', type=int, default=-1, help='Resolution to start with.')
parser.add_argument('--object_id', type=int, default=-1, help='Object index.')
args, unknown = parser.parse_known_args()
cfg = load_config(args.config, 'configs/default.yaml')
resolutions=[32, 64, 128, 256]
iterations=[1000, 1000, 1000, 200]
lrs=[2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)] # reduce lr
for idx,(res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
if res<args.start_res:
continue
if res>cfg['model']['grid_res']:
continue
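        # smooth the PSR grid more (larger psr_sigma) at the finest resolution, and even more for the noisy Thingi10K setting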
psr_sigma= 2 if res<=128 else 3
if res > 128:
psr_sigma = 5 if 'thingi_noisy' in args.config else 3
if args.object_id != -1:
out_dir = os.path.join(cfg['train']['out_dir'], 'object_%02d'%args.object_id, 'res_%d'%res)
else:
out_dir = os.path.join(cfg['train']['out_dir'], 'res_%d'%res)
# sample from mesh when resampling is enabled, otherwise reuse the pointcloud
init_shape='mesh' if cfg['train']['resample_every']>0 else 'pointcloud'
if args.object_id != -1:
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
'object_%02d'%args.object_id, 'res_%d' % (resolutions[idx-1]),
'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
else:
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
'res_%d' % (resolutions[idx-1]),
'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
cmd = 'export MKL_SERVICE_FORCE_INTEL=1 && '
cmd += "python optim.py %s --model:grid_res %d --model:psr_sigma %d \
--train:input_mesh %s --train:total_epochs %d \
--train:out_dir %s --train:lr_pcl %f \
--data:object_id %d" % (
args.config,
res,
psr_sigma,
input_mesh,
iteration,
out_dir,
lr,
args.object_id)
print(cmd)
os.system(cmd)
if __name__=="__main__":
main()

View file

@ -0,0 +1,8 @@
#!/bin/bash
mkdir -p data
cd data
echo "Start downloading ..."
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/demo.zip
unzip demo.zip
rm demo.zip
echo "Done!"

View file

@ -0,0 +1,8 @@
#!/bin/bash
mkdir -p data
cd data
echo "Start downloading data for optimization-based setting (~200 MB)"
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/optim_data.zip
unzip optim_data.zip
rm optim_data.zip
echo "Done!"

View file

@ -0,0 +1,8 @@
#!/bin/bash
mkdir -p data
cd data
echo "Start downloading preprocessed ShapeNet data (~220G)"
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/shapenet_psr.zip
unzip shapenet_psr.zip
rm shapenet_psr.zip
echo "Done!"

101
scripts/process_shapenet.py Normal file
View file

@ -0,0 +1,101 @@
import os
import torch
import time
import multiprocessing
import numpy as np
from tqdm import tqdm
from src.dpsr import DPSR
data_path = 'data/ShapeNet' # path for ShapeNet from ONet
base = 'data' # output base directory
dataset_name = 'shapenet_psr'
multiprocess = True
njobs = 8
save_pointcloud = True
save_psr_field = True
resolution = 128
zero_level = 0.0
num_points = 100000
padding = 1.2
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
def process_one(obj):
obj_name = obj.split('/')[-1]
c = obj.split('/')[-2]
# create new for the current object
out_path_cur = os.path.join(base, dataset_name, c)
out_path_cur_obj = os.path.join(out_path_cur, obj_name)
os.makedirs(out_path_cur_obj, exist_ok=True)
gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz')
data = np.load(gt_path)
points = data['points']
normals = data['normals']
# normalize the point to [0, 1)
points = points / padding + 0.5
# to scale back during inference, we should:
#! p = (p - 0.5) * padding
if save_pointcloud:
outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
# np.savez(outdir, points=points, normals=normals)
np.savez(outdir, points=data['points'], normals=data['normals'])
# return
if save_psr_field:
psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None],
torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16)
outdir = os.path.join(out_path_cur_obj, 'psr.npz')
np.savez(outdir, psr=psr_gt)
def main(c):
    for split in ['train', 'val', 'test']:
        print('---------------------------------------')
        print('Processing {} {}'.format(c, split))
        print('---------------------------------------')
        fname = os.path.join(data_path, c, split+'.lst')
with open(fname, 'r') as f:
obj_list = f.read().splitlines()
obj_list = [c+'/'+s for s in obj_list]
if multiprocess:
# multiprocessing.set_start_method('spawn', force=True)
pool = multiprocessing.Pool(njobs)
try:
for _ in tqdm(pool.imap_unordered(process_one, obj_list), total=len(obj_list)):
pass
# pool.map_async(process_one, obj_list).get()
except KeyboardInterrupt:
# Allow ^C to interrupt from any thread.
exit()
pool.close()
else:
for obj in tqdm(obj_list):
process_one(obj)
print('Done Processing {} {}!'.format(c, split))
if __name__ == "__main__":
classes = ['02691156', '02828884', '02933112',
'02958343', '03211117', '03001627',
'03636649', '03691459', '04090263',
'04256520', '04379243', '04401088', '04530566']
t_start = time.time()
for c in classes:
main(c)
t_end = time.time()
print('Total processing time: ', t_end - t_start)

0
src/__init__.py Normal file
View file

146
src/config.py Normal file
View file

@ -0,0 +1,146 @@
import yaml
from torchvision import transforms
from src import data, generation
from src.dpsr import DPSR
from ipdb import set_trace as st
# Generator for final mesh extraction
def get_generator(model, cfg, device, **kwargs):
''' Returns the generator object.
Args:
model (nn.Module): Occupancy Network model
cfg (dict): imported yaml config
device (device): pytorch device
'''
if cfg['generation']['psr_resolution'] == 0:
psr_res = cfg['model']['grid_res']
psr_sigma = cfg['model']['psr_sigma']
else:
psr_res = cfg['generation']['psr_resolution']
psr_sigma = cfg['generation']['psr_sigma']
dpsr = DPSR(res=(psr_res, psr_res, psr_res),
sig= psr_sigma).to(device)
generator = generation.Generator3D(
model,
device=device,
threshold=cfg['data']['zero_level'],
sample=cfg['generation']['use_sampling'],
input_type = cfg['data']['input_type'],
padding=cfg['data']['padding'],
dpsr=dpsr,
psr_tanh=cfg['model']['psr_tanh']
)
return generator
# Datasets
def get_dataset(mode, cfg, return_idx=False):
''' Returns the dataset.
Args:
model (nn.Module): the model which is used
cfg (dict): config dictionary
return_idx (bool): whether to include an ID field
'''
dataset_type = cfg['data']['dataset']
dataset_folder = cfg['data']['path']
categories = cfg['data']['class']
# Get split
splits = {
'train': cfg['data']['train_split'],
'val': cfg['data']['val_split'],
'test': cfg['data']['test_split'],
'vis': cfg['data']['val_split'],
}
split = splits[mode]
# Create dataset
if dataset_type == 'Shapes3D':
fields = get_data_fields(mode, cfg)
# Input fields
inputs_field = get_inputs_field(mode, cfg)
if inputs_field is not None:
fields['inputs'] = inputs_field
if return_idx:
fields['idx'] = data.IndexField()
dataset = data.Shapes3dDataset(
dataset_folder, fields,
split=split,
categories=categories,
cfg = cfg
)
else:
raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])
return dataset
def get_inputs_field(mode, cfg):
''' Returns the inputs fields.
Args:
mode (str): the mode which is used
cfg (dict): config dictionary
'''
input_type = cfg['data']['input_type']
if input_type is None:
inputs_field = None
elif input_type == 'pointcloud':
noise_level = cfg['data']['pointcloud_noise']
if cfg['data']['pointcloud_outlier_ratio']>0:
transform = transforms.Compose([
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
data.PointcloudNoise(noise_level),
data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio'])
])
else:
transform = transforms.Compose([
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
data.PointcloudNoise(noise_level)
])
data_type = cfg['data']['data_type']
inputs_field = data.PointCloudField(
cfg['data']['pointcloud_file'], data_type, transform,
multi_files= cfg['data']['multi_files']
)
else:
raise ValueError(
'Invalid input type (%s)' % input_type)
return inputs_field
def get_data_fields(mode, cfg):
''' Returns the data fields.
Args:
mode (str): the mode which is used
cfg (dict): imported yaml config
'''
data_type = cfg['data']['data_type']
fields = {}
if (mode in ('val', 'test')):
transform = data.SubsamplePointcloud(100000)
else:
transform = data.SubsamplePointcloud(cfg['data']['num_gt_points'])
data_name = cfg['data']['pointcloud_file']
fields['gt_points'] = data.PointCloudField(data_name,
transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files'])
if data_type == 'psr_full':
if mode != 'test':
fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files'])
else:
raise ValueError('Invalid data type (%s)' % data_type)
return fields

25
src/data/__init__.py Normal file
View file

@ -0,0 +1,25 @@
from src.data.core import (
Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together
)
from src.data.fields import (
IndexField, PointCloudField, FullPSRField
)
from src.data.transforms import (
PointcloudNoise, SubsamplePointcloud,
PointcloudOutliers,
)
__all__ = [
# Core
Shapes3dDataset,
collate_remove_none,
worker_init_fn,
# Fields
IndexField,
PointCloudField,
FullPSRField,
# Transforms
PointcloudNoise,
SubsamplePointcloud,
PointcloudOutliers,
]

237
src/data/core.py Normal file
View file

@ -0,0 +1,237 @@
import os
import logging
from torch.utils import data
from pdb import set_trace as st
import numpy as np
import yaml
logger = logging.getLogger(__name__)
# Fields
class Field(object):
''' Data fields class.
'''
def load(self, data_path, idx, category):
''' Loads a data point.
Args:
data_path (str): path to data file
idx (int): index of data point
category (int): index of category
'''
raise NotImplementedError
def check_complete(self, files):
''' Checks if set is complete.
Args:
files: files
'''
raise NotImplementedError
class Shapes3dDataset(data.Dataset):
''' 3D Shapes dataset class.
'''
def __init__(self, dataset_folder, fields, split=None,
categories=None, no_except=True, transform=None, cfg=None):
''' Initialization of the 3D shape dataset.
Args:
dataset_folder (str): dataset folder
fields (dict): dictionary of fields
split (str): which split is used
categories (list): list of categories to use
no_except (bool): no exception
transform (callable): transformation applied to data points
cfg (yaml): config file
'''
# Attributes
self.dataset_folder = dataset_folder
self.fields = fields
self.no_except = no_except
self.transform = transform
self.cfg = cfg
# If categories is None, use all subfolders
if categories is None:
categories = os.listdir(dataset_folder)
categories = [c for c in categories
if os.path.isdir(os.path.join(dataset_folder, c))]
# Read metadata file
metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
if os.path.exists(metadata_file):
with open(metadata_file, 'r') as f:
self.metadata = yaml.load(f, Loader=yaml.Loader)
else:
self.metadata = {
c: {'id': c, 'name': 'n/a'} for c in categories
}
# Set index
for c_idx, c in enumerate(categories):
self.metadata[c]['idx'] = c_idx
# Get all models
self.models = []
for c_idx, c in enumerate(categories):
subpath = os.path.join(dataset_folder, c)
if not os.path.isdir(subpath):
logger.warning('Category %s does not exist in dataset.' % c)
continue
if split is None:
self.models += [
{'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ]
]
else:
split_file = os.path.join(subpath, split + '.lst')
with open(split_file, 'r') as f:
models_c = f.read().split('\n')
if '' in models_c:
models_c.remove('')
self.models += [
{'category': c, 'model': m}
for m in models_c
]
# precompute
self.split = split
def __len__(self):
''' Returns the length of the dataset.
'''
return len(self.models)
def __getitem__(self, idx):
''' Returns an item of the dataset.
Args:
idx (int): ID of data point
'''
category = self.models[idx]['category']
model = self.models[idx]['model']
c_idx = self.metadata[category]['idx']
model_path = os.path.join(self.dataset_folder, category, model)
data = {}
info = c_idx
if self.cfg['data']['multi_files'] is not None:
idx = np.random.randint(self.cfg['data']['multi_files'])
if self.split != 'train':
idx = 0
for field_name, field in self.fields.items():
try:
field_data = field.load(model_path, idx, info)
except Exception:
if self.no_except:
logger.warn(
'Error occured when loading field %s of model %s'
% (field_name, model)
)
return None
else:
raise
if isinstance(field_data, dict):
for k, v in field_data.items():
if k is None:
data[field_name] = v
else:
data['%s.%s' % (field_name, k)] = v
else:
data[field_name] = field_data
if self.transform is not None:
data = self.transform(data)
return data
def get_model_dict(self, idx):
return self.models[idx]
def test_model_complete(self, category, model):
''' Tests if model is complete.
Args:
model (str): modelname
'''
model_path = os.path.join(self.dataset_folder, category, model)
files = os.listdir(model_path)
for field_name, field in self.fields.items():
if not field.check_complete(files):
logger.warn('Field "%s" is incomplete: %s'
% (field_name, model_path))
return False
return True
def collate_remove_none(batch):
''' Collater that removes None entries and puts each remaining data field into a
tensor with outer dimension batch size.
Args:
batch: batch
'''
batch = list(filter(lambda x: x is not None, batch))
return data.dataloader.default_collate(batch)
def collate_stack_together(batch):
''' Collater that concatenates the point sets of all batch elements into a single
sample and records the originating batch index of every point in 'batch_ind'.
Args:
batch: batch
'''
batch = list(filter(lambda x: x is not None, batch))
keys = batch[0].keys()
concat = {}
if len(batch)>1:
for key in keys:
key_val = [item[key] for item in batch]
concat[key] = np.concatenate(key_val, axis=0)
if key == 'inputs':
n_pts = [item[key].shape[0] for item in batch]
concat['batch_ind'] = np.concatenate(
[i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
return data.dataloader.default_collate([concat])
else:
n_pts = batch[0]['inputs'].shape[0]
batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int)
return data.dataloader.default_collate(batch)
def worker_init_fn(worker_id):
''' Worker init function to ensure true randomness.
'''
def set_num_threads(nt):
try:
import mkl; mkl.set_num_threads(nt)
except:
pass
torch.set_num_threads(1)
os.environ['IPC_ENABLE']='1'
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
os.environ[o] = str(nt)
random_data = os.urandom(4)
base_seed = int.from_bytes(random_data, byteorder="big")
np.random.seed(base_seed + worker_id)
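
# Standalone sketch of the batching helpers (array sizes below are illustrative):
# collate_stack_together concatenates variable-sized point sets from several items
# into one sample and records, per point, which item it came from in 'batch_ind'.
if __name__ == '__main__':
    example_batch = [
        {'inputs': np.random.rand(16, 3).astype(np.float32)},
        None,  # e.g. an item whose field loading failed and returned None
        {'inputs': np.random.rand(24, 3).astype(np.float32)},
    ]
    out = collate_stack_together(example_batch)
    print(out['inputs'].shape, out['batch_ind'].shape)  # (1, 40, 3), (1, 40)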

118
src/data/fields.py Normal file
View file

@ -0,0 +1,118 @@
import os
import glob
import time
import random
from PIL import Image
import numpy as np
import trimesh
from src.data.core import Field
from pdb import set_trace as st
class IndexField(Field):
''' Basic index field.'''
def load(self, model_path, idx, category):
''' Loads the index field.
Args:
model_path (str): path to model
idx (int): ID of data point
category (int): index of category
'''
return idx
def check_complete(self, files):
''' Check if field is complete.
Args:
files: files
'''
return True
class FullPSRField(Field):
def __init__(self, transform=None, multi_files=None):
self.transform = transform
# self.unpackbits = unpackbits
self.multi_files = multi_files
def load(self, model_path, idx, category):
# try:
# t0 = time.time()
if self.multi_files is not None:
psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx))
else:
psr_path = os.path.join(model_path, 'psr.npz')
psr_dict = np.load(psr_path)
# t1 = time.time()
psr = psr_dict['psr']
psr = psr.astype(np.float32)
# t2 = time.time()
# print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0))
data = {None: psr}
if self.transform is not None:
data = self.transform(data)
return data
class PointCloudField(Field):
''' Point cloud field.
It provides the field used for point cloud data. These are the points
randomly sampled on the mesh.
Args:
file_name (str): file name
transform (list): list of transformations applied to data points
multi_files (int): number of files per model to sample from
'''
def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
self.file_name = file_name
self.data_type = data_type # to make sure the range of input is correct
self.transform = transform
self.multi_files = multi_files
self.padding = padding
self.scale = scale
def load(self, model_path, idx, category):
''' Loads the data point.
Args:
model_path (str): path to model
idx (int): ID of data point
category (int): index of category
'''
if self.multi_files is None:
file_path = os.path.join(model_path, self.file_name)
else:
# num = np.random.randint(self.multi_files)
# file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num))
file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx))
pointcloud_dict = np.load(file_path)
points = pointcloud_dict['points'].astype(np.float32)
normals = pointcloud_dict['normals'].astype(np.float32)
data = {
None: points,
'normals': normals,
}
if self.transform is not None:
data = self.transform(data)
if self.data_type == 'psr_full':
# scale the point cloud to the range of (0, 1)
data[None] = data[None] / self.scale + 0.5
return data
def check_complete(self, files):
''' Check if field is complete.
Args:
files: files
'''
complete = (self.file_name in files)
return complete
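
# Standalone sketch of loading a point cloud field; the temporary folder and random
# arrays below stand in for a real model directory that contains a pointcloud.npz
# with 'points' and 'normals' entries.
if __name__ == '__main__':
    import tempfile
    model_dir = tempfile.mkdtemp()
    np.savez(os.path.join(model_dir, 'pointcloud.npz'),
             points=(np.random.rand(2048, 3) - 0.5).astype(np.float32),
             normals=np.random.randn(2048, 3).astype(np.float32))
    field = PointCloudField('pointcloud.npz', data_type='psr_full')
    out = field.load(model_dir, idx=0, category=0)
    # for 'psr_full' the points are rescaled to the range (0, 1)
    print(out[None].shape, out['normals'].shape, out[None].min(), out[None].max())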

86
src/data/transforms.py Normal file
View file

@ -0,0 +1,86 @@
import numpy as np
# Transforms
class PointcloudNoise(object):
''' Point cloud noise transformation class.
It adds noise to point cloud data.
Args:
stddev (float): standard deviation of the Gaussian noise
'''
def __init__(self, stddev):
self.stddev = stddev
def __call__(self, data):
''' Calls the transformation.
Args:
data (dictionary): data dictionary
'''
data_out = data.copy()
points = data[None]
noise = self.stddev * np.random.randn(*points.shape)
noise = noise.astype(np.float32)
data_out[None] = points + noise
return data_out
class PointcloudOutliers(object):
''' Point cloud outlier transformation class.
It adds outliers to point cloud data.
Args:
ratio (float): fraction of points to be replaced by outliers
'''
def __init__(self, ratio):
self.ratio = ratio
def __call__(self, data):
''' Calls the transformation.
Args:
data (dictionary): data dictionary
'''
data_out = data.copy()
points = data[None].copy()  # copy so the input array is not modified in place
n_points = points.shape[0]
n_outlier_points = int(n_points*self.ratio)
ind = np.random.randint(0, n_points, n_outlier_points)
outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
outliers = outliers.astype(np.float32)
points[ind] = outliers
data_out[None] = points
return data_out
class SubsamplePointcloud(object):
''' Point cloud subsampling transformation class.
It subsamples the point cloud data.
Args:
N (int): number of points to be subsampled
'''
def __init__(self, N):
self.N = N
def __call__(self, data):
''' Calls the transformation.
Args:
data (dict): data dictionary
'''
data_out = data.copy()
points = data[None]
indices = np.random.randint(points.shape[0], size=self.N)
data_out[None] = points[indices, :]
if 'normals' in data.keys():
normals = data['normals']
data_out['normals'] = normals[indices, :]
return data_out
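
# Standalone sketch chaining the transforms in the same order as the
# transforms.Compose used by the data config (sizes and parameters are illustrative):
# subsample to a fixed size, add Gaussian noise, then replace a fraction of the
# points with uniform outliers.
if __name__ == '__main__':
    pcl = {
        None: (np.random.rand(10000, 3) - 0.5).astype(np.float32),
        'normals': np.random.randn(10000, 3).astype(np.float32),
    }
    for t in [SubsamplePointcloud(3000), PointcloudNoise(0.005), PointcloudOutliers(0.1)]:
        pcl = t(pcl)
    print(pcl[None].shape, pcl['normals'].shape)  # (3000, 3), (3000, 3)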

228
src/data_loader.py Normal file
View file

@ -0,0 +1,228 @@
import os
import cv2
import imageio
import torch
import torch.nn.functional as F
import numpy as np
from glob import glob
from torch.utils import data
from pytorch3d.renderer import PerspectiveCameras
from skimage import img_as_float32
##################################################
# Below are for the differentiable renderer
# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py
def load_rgb(path):
img = imageio.imread(path)
img = img_as_float32(img)
# pixel values between [-1,1]
# img -= 0.5
# img *= 2.
# img = img.transpose(2, 0, 1)
return img
def load_mask(path):
alpha = imageio.imread(path, as_gray=True)
alpha = img_as_float32(alpha)
object_mask = alpha > 127.5
return object_mask
def get_camera_params(uv, pose, intrinsics):
if pose.shape[1] == 7: #In case of quaternion vector representation
cam_loc = pose[:, 4:]
R = quat_to_rot(pose[:,:4])
p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
p[:, :3, :3] = R
p[:, :3, 3] = cam_loc
else: # In case of pose matrix representation
cam_loc = pose[:, :3, 3]
p = pose
batch_size, num_samples, _ = uv.shape
depth = torch.ones((batch_size, num_samples))
x_cam = uv[:, :, 0].view(batch_size, -1)
y_cam = uv[:, :, 1].view(batch_size, -1)
z_cam = depth.view(batch_size, -1)
pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
# permute for batch matrix product
pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
ray_dirs = world_coords - cam_loc[:, None, :]
ray_dirs = F.normalize(ray_dirs, dim=2)
return ray_dirs, cam_loc
def quat_to_rot(q):
batch_size, _ = q.shape
q = F.normalize(q, dim=1)
R = torch.ones((batch_size, 3,3)).cuda()
qr=q[:,0]
qi = q[:, 1]
qj = q[:, 2]
qk = q[:, 3]
R[:, 0, 0]=1-2 * (qj**2 + qk**2)
R[:, 0, 1] = 2 * (qj *qi -qk*qr)
R[:, 0, 2] = 2 * (qi * qk + qr * qj)
R[:, 1, 0] = 2 * (qj * qi + qk * qr)
R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
R[:, 1, 2] = 2*(qj*qk - qi*qr)
R[:, 2, 0] = 2 * (qk * qi-qj * qr)
R[:, 2, 1] = 2 * (qj*qk + qi*qr)
R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
return R
def lift(x, y, z, intrinsics):
# parse intrinsics
# intrinsics = intrinsics.cuda()
fx = intrinsics[:, 0, 0]
fy = intrinsics[:, 1, 1]
cx = intrinsics[:, 0, 2]
cy = intrinsics[:, 1, 2]
sk = intrinsics[:, 0, 1]
x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
# homogeneous
return torch.stack((x_lift, y_lift, z, torch.ones_like(z)), dim=-1)
class PixelNeRFDTUDataset(data.Dataset):
"""
Processed DTU from pixelNeRF
"""
def __init__(self,
data_dir='data/DTU',
scan_id=65,
img_size=None,
device=None,
fixed_scale=0,
):
data_dir = os.path.join(data_dir, "scan{}".format(scan_id))
rgb_paths = [
x for x in glob(os.path.join(data_dir, "image", "*"))
if (x.endswith(".jpg") or x.endswith(".png"))
]
rgb_paths = sorted(rgb_paths)
mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))
if len(mask_paths) == 0:
mask_paths = [None] * len(rgb_paths)
sel_indices = np.arange(len(rgb_paths))
cam_path = os.path.join(data_dir, "cameras.npz")
all_cam = np.load(cam_path)
all_imgs = []
all_poses = []
all_masks = []
all_rays = []
all_light_pose = []
all_K = []
all_R = []
all_T = []
for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)):
i = sel_indices[idx]
rgb = load_rgb(rgb_path)
mask = load_mask(mask_path)
rgb[~mask] = 0.
rgb = torch.from_numpy(rgb).float().to(device)
mask = torch.from_numpy(mask).float().to(device)
x_scale = y_scale = 1.0
xy_delta = 0.0
P = all_cam["world_mat_" + str(i)]
P = P[:3]
# scale the original shape to really [-0.9, 0.9]
if fixed_scale!=0.:
scale_mat_new = np.eye(4, 4)
scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new
else:
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)]
P = P[:3, :4]
K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
K = K / K[2, 2]
fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
########!!!!!
RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0)
tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0)
focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
im_size = (rgb.shape[1], rgb.shape[0])
# check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC
s = min(im_size)
focal[:, 0] = focal[:, 0] * 2 / (s-1)
focal[:, 1] = focal[:, 1] * 2 /(s-1)
pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1)
pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1)
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc,
device=device, R=RR, T=tt)
# calculate camera rays
uv = uv_creation(im_size)[None].float()
pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = R.transpose()
pose[:3,3] = (t[:3] / t[3])[:,0]
pose = torch.from_numpy(pose)[None].float()
intrinsics = np.eye(4)
intrinsics[:3, :3] = K
intrinsics[0, 1] = 0. #! remove skew for now
intrinsics = torch.from_numpy(intrinsics)[None].float()
rays, _ = get_camera_params(uv, pose, intrinsics)
rays = -rays.to(device)
all_poses.append(camera)
all_imgs.append(rgb)
all_masks.append(mask)
all_rays.append(rays)
all_light_pose.append(pose)
# only for neural renderer
all_K.append(torch.tensor(K).to(device))
all_R.append(torch.tensor(R).to(device))
all_T.append(torch.tensor(t[:3]/t[3]).to(device))
all_imgs = torch.stack(all_imgs)
all_masks = torch.stack(all_masks)
all_rays = torch.stack(all_rays)
all_light_pose = torch.stack(all_light_pose).squeeze()
# only for neural renderer
all_K = torch.stack(all_K).float()
all_R = torch.stack(all_R).float()
all_T = torch.stack(all_T).permute(0, 2, 1).float()
uv = uv_creation((all_imgs.size(2), all_imgs.size(1)))
self.data = {'rgbs': all_imgs,
'masks': all_masks,
'poses': all_poses,
'rays': all_rays,
'uv': uv,
'light_pose': all_light_pose, # for rendering lights
'K': all_K,
'R': all_R,
'T': all_T,
}
def __len__(self):
return 1
def __getitem__(self, idx):
return self.data

75
src/dpsr.py Normal file
View file

@ -0,0 +1,75 @@
import torch
import torch.nn as nn
from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
import numpy as np
import torch.fft
class DPSR(nn.Module):
def __init__(self, res, sig=10, scale=True, shift=True):
"""
:param res: tuple of output field resolution. eg., (128,128)
:param sig: degree of gaussian smoothing
"""
super(DPSR, self).__init__()
self.res = res
self.sig = sig
self.dim = len(res)
self.denom = np.prod(res)
G = spec_gaussian_filter(res=res, sig=sig).float()
# self.G.requires_grad = False # True, if we also make sig a learnable parameter
self.omega = fftfreqs(res, dtype=torch.float32)
self.scale = scale
self.shift = shift
self.register_buffer("G", G)
def forward(self, V, N):
"""
:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
:param N: (batch, nv, 2 or 3) tensor for point normals
:return phi: (batch, res, res, ...) tensor of output indicator function field
"""
assert(V.shape == N.shape) # [b, nv, ndims]
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
#!!! OLD
# ras_s = torch.rfft(ras_p, signal_ndim=self.dim) # [b, n_dim, dim0, dim1, dim2/2+1, 2]
# ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+2))+[1, self.dim+2]))
# N_ = (ras_s * self.G) # [b, n_dim, dim0, dim1, dim2/2+1, 2]
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
omega *= 2 * np.pi # normalize frequencies
omega = omega.to(V.device)
# DivN = torch.sum(-img(N_) * omega, dim=-2) #!!! OLD [b, dim0, dim1, dim2/2+1, 2]
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b]
Phi[tuple([0] * self.dim)] = 0
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
# phi = torch.irfft(Phi, signal_ndim=self.dim, signal_sizes=self.res)#!!! OLD [b, dim0, dim1, dim2]
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
if self.shift or self.scale:
# ensure values at points are zero
fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
if self.shift: # offset points to have mean of 0
offset = torch.mean(fv, dim=-1) # [b,]
phi -= offset.view(*tuple([-1] + [1] * self.dim))
phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
fv0 = phi[tuple([0] * self.dim)] # [b,]
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
if self.scale:
# phi = phi / fv0.view(*tuple([-1] + [1] * self.dim)) * 0.5
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
return phi
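
# Standalone sketch of the differentiable Poisson solver on random oriented points
# (resolution, point count and smoothing are illustrative). Input coordinates are
# assumed to lie in [0, 1), matching the normalization used elsewhere in the code.
if __name__ == '__main__':
    dpsr = DPSR(res=(128, 128, 128), sig=10)
    points = torch.rand(1, 2048, 3)
    normals = torch.randn(1, 2048, 3)
    normals = normals / normals.norm(dim=-1, keepdim=True)
    phi = dpsr(points, normals)
    print(phi.shape)  # torch.Size([1, 128, 128, 128]), the indicator grid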

168
src/eval.py Normal file
View file

@ -0,0 +1,168 @@
import logging
import numpy as np
import trimesh
from pykdtree.kdtree import KDTree
EMPTY_PCL_DICT = {
'completeness': np.sqrt(3),
'accuracy': np.sqrt(3),
'completeness2': 3,
'accuracy2': 3,
'chamfer': 6,
}
EMPTY_PCL_DICT_NORMALS = {
'normals completeness': -1.,
'normals accuracy': -1.,
'normals': -1.,
}
logger = logging.getLogger(__name__)
class MeshEvaluator(object):
''' Mesh evaluation class.
It handles the mesh evaluation process.
Args:
n_points (int): number of points to be used for evaluation
'''
def __init__(self, n_points=100000):
self.n_points = n_points
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)):
''' Evaluates a mesh.
Args:
mesh (trimesh): mesh which should be evaluated
pointcloud_tgt (numpy array): target point cloud
normals_tgt (numpy array): target normals
thresholds (numpy array): threshold values for the F-score calculation
'''
if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
pointcloud, idx = mesh.sample(self.n_points, return_index=True)
pointcloud = pointcloud.astype(np.float32)
normals = mesh.face_normals[idx]
else:
pointcloud = np.empty((0, 3))
normals = np.empty((0, 3))
out_dict = self.eval_pointcloud(
pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
return out_dict
def eval_pointcloud(self, pointcloud, pointcloud_tgt,
normals=None, normals_tgt=None,
thresholds=np.linspace(1./1000, 1, 1000)):
''' Evaluates a point cloud.
Args:
pointcloud (numpy array): predicted point cloud
pointcloud_tgt (numpy array): target point cloud
normals (numpy array): predicted normals
normals_tgt (numpy array): target normals
thresholds (numpy array): threshold values for the F-score calculation
'''
# Return maximum losses if pointcloud is empty
if pointcloud.shape[0] == 0:
logger.warn('Empty pointcloud / mesh detected!')
out_dict = EMPTY_PCL_DICT.copy()
if normals is not None and normals_tgt is not None:
out_dict.update(EMPTY_PCL_DICT_NORMALS)
return out_dict
pointcloud = np.asarray(pointcloud)
pointcloud_tgt = np.asarray(pointcloud_tgt)
# Completeness: how far are the points of the target point cloud
# from the predicted point cloud
completeness, completeness_normals = distance_p2p(
pointcloud_tgt, normals_tgt, pointcloud, normals
)
recall = get_threshold_percentage(completeness, thresholds)
completeness2 = completeness**2
completeness = completeness.mean()
completeness2 = completeness2.mean()
completeness_normals = completeness_normals.mean()
# Accuracy: how far are the points of the predicted pointcloud
# from the target pointcloud
accuracy, accuracy_normals = distance_p2p(
pointcloud, normals, pointcloud_tgt, normals_tgt
)
precision = get_threshold_percentage(accuracy, thresholds)
accuracy2 = accuracy**2
accuracy = accuracy.mean()
accuracy2 = accuracy2.mean()
accuracy_normals = accuracy_normals.mean()
# Chamfer distance
chamferL2 = 0.5 * (completeness2 + accuracy2)
normals_correctness = (
0.5 * completeness_normals + 0.5 * accuracy_normals
)
chamferL1 = 0.5 * (completeness + accuracy)
# F-Score
F = [
2 * precision[i] * recall[i] / (precision[i] + recall[i])
for i in range(len(precision))
]
out_dict = {
'completeness': completeness,
'accuracy': accuracy,
'normals completeness': completeness_normals,
'normals accuracy': accuracy_normals,
'normals': normals_correctness,
'completeness2': completeness2,
'accuracy2': accuracy2,
'chamfer-L2': chamferL2,
'chamfer-L1': chamferL1,
'f-score': F[9], # threshold = 1.0%
'f-score-15': F[14], # threshold = 1.5%
'f-score-20': F[19], # threshold = 2.0%
}
return out_dict
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
''' Computes minimal distances of each point in points_src to points_tgt.
Args:
points_src (numpy array): source points
normals_src (numpy array): source normals
points_tgt (numpy array): target points
normals_tgt (numpy array): target normals
'''
kdtree = KDTree(points_tgt)
dist, idx = kdtree.query(points_src)
if normals_src is not None and normals_tgt is not None:
normals_src = \
normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
normals_tgt = \
normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
# Handle normals that point into wrong direction gracefully
# (mostly due to the method not caring about this in generation)
normals_dot_product = np.abs(normals_dot_product)
else:
normals_dot_product = np.array(
[np.nan] * points_src.shape[0], dtype=np.float32)
return dist, normals_dot_product
def get_threshold_percentage(dist, thresholds):
''' Computes the fraction of distances that fall below each threshold.
Args:
dist (numpy array): calculated distance
thresholds (numpy array): threshold values for the F-score calculation
'''
in_threshold = [
(dist <= t).mean() for t in thresholds
]
return in_threshold
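
# Standalone sanity check (random data, purely illustrative): comparing a point
# cloud against an identical copy should give a Chamfer distance of ~0 and an
# F-score of 1. In the pipeline, eval_mesh samples the predicted cloud from the
# reconstructed mesh before calling eval_pointcloud.
if __name__ == '__main__':
    evaluator = MeshEvaluator(n_points=10000)
    gt_points = np.random.rand(10000, 3).astype(np.float32)
    gt_normals = np.random.randn(10000, 3).astype(np.float32)
    gt_normals /= np.linalg.norm(gt_normals, axis=-1, keepdims=True)
    metrics = evaluator.eval_pointcloud(
        gt_points.copy(), gt_points, gt_normals.copy(), gt_normals)
    print(metrics['chamfer-L1'], metrics['f-score'])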

63
src/generation.py Normal file
View file

@ -0,0 +1,63 @@
import torch
import time
import trimesh
import numpy as np
from src.utils import mc_from_psr
class Generator3D(object):
''' Generator class for Shape As Points.
It provides functions to extract the final mesh from the predicted point cloud.
Args:
model (nn.Module): trained point encoder/decoder network
points_batch_size (int): batch size for points evaluation
threshold (float): iso-level used for Marching Cubes
device (device): pytorch device
padding (float): how much padding should be used
sample (bool): whether z should be sampled
input_type (str): type of input
dpsr (nn.Module): differentiable Poisson solver (DPSR) instance
psr_tanh (bool): whether tanh is applied to the PSR grid
'''
def __init__(self, model, points_batch_size=100000,
threshold=0.5, device=None, padding=0.1,
sample=False, input_type = None, dpsr=None, psr_tanh=True):
self.model = model.to(device)
self.points_batch_size = points_batch_size
self.threshold = threshold
self.device = device
self.input_type = input_type
self.padding = padding
self.sample = sample
self.dpsr = dpsr
self.psr_tanh = psr_tanh
def generate_mesh(self, data, return_stats=True):
''' Generates the output mesh.
Args:
data (tensor): data tensor
return_stats (bool): whether stats should be returned
'''
self.model.eval()
device = self.device
stats_dict = {}
p = data.get('inputs', torch.empty(1, 0)).to(device)
t0 = time.time()
points, normals = self.model(p)
t1 = time.time()
psr_grid = self.dpsr(points, normals)
t2 = time.time()
v, f, _ = mc_from_psr(psr_grid,
zero_level=self.threshold)
stats_dict['pcl'] = t1 - t0
stats_dict['dpsr'] = t2 - t1
stats_dict['mc'] = time.time() - t2
stats_dict['total'] = time.time() - t0
if return_stats:
return v, f, points, normals, stats_dict
else:
return v, f, points, normals

181
src/model.py Normal file
View file

@ -0,0 +1,181 @@
import torch
import numpy as np
import time
from src.utils import point_rasterize, grid_interp, mc_from_psr, \
calc_inters_points
from src.dpsr import DPSR
import torch.nn as nn
from src.network import encoder_dict, decoder_dict
from src.network.utils import map2local
class PSR2Mesh(torch.autograd.Function):
@staticmethod
def forward(ctx, psr_grid):
"""
In the forward pass we receive a Tensor containing the input and return
a Tensor containing the output. ctx is a context object that can be used
to stash information for backward computation. You can cache arbitrary
objects for use in the backward pass using the ctx.save_for_backward method.
"""
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
verts = verts.unsqueeze(0)
faces = faces.unsqueeze(0)
normals = normals.unsqueeze(0)
res = torch.tensor(psr_grid.detach().shape[2])
ctx.save_for_backward(verts, normals, res)
return verts, faces, normals
@staticmethod
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
"""
vert_pts, normals, res = ctx.saved_tensors
res = (res.item(), res.item(), res.item())
# matrix multiplication between dL/dV and dV/dPSR
# dV/dPSR = - normals
grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
return grad_grid
class PSR2SurfacePoints(torch.autograd.Function):
@staticmethod
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
verts = verts * 2. - 1. # within the range of [-1, 1]
p_all, n_all, mask_all = [], [], []
for i in range(len(poses)):
pose = poses[i]
if mask_sample is not None:
p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i])
else:
p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size)
n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze()
p_all.append(p_inters)
n_all.append(n_inters)
mask_all.append(mask)
p_inters_all = torch.cat(p_all, dim=0)
n_inters_all = torch.cat(n_all, dim=0)
mask_visible = torch.stack(mask_all, dim=0)
res = torch.tensor(psr_grid.detach().shape[2])
ctx.save_for_backward(p_inters_all, n_inters_all, res)
return p_inters_all, mask_visible
@staticmethod
def backward(ctx, dL_dp, dL_dmask):
pts, pts_n, res = ctx.saved_tensors
res = (res.item(), res.item(), res.item())
# grad from the p_inters via MLP renderer
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
return grad_grid_pts, None, None, None, None, None
class Encode2Points(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
encoder = cfg['model']['encoder']
decoder = cfg['model']['decoder']
dim = cfg['data']['dim'] # input dim
c_dim = cfg['model']['c_dim']
encoder_kwargs = cfg['model']['encoder_kwargs']
if encoder_kwargs is None:
encoder_kwargs = {}
decoder_kwargs = cfg['model']['decoder_kwargs']
padding = cfg['data']['padding']
self.predict_normal = cfg['model']['predict_normal']
self.predict_offset = cfg['model']['predict_offset']
out_dim = 3
out_dim_offset = 3
num_offset = cfg['data']['num_offset']
# each point predict more than one offset to add output points
if num_offset > 1:
out_dim_offset = out_dim * num_offset
self.num_offset = num_offset
# local mapping
self.map2local = None
local_mapping = None
if cfg['model']['local_coord']:
if 'unet' in encoder_kwargs.keys():
unit_size = 1 / encoder_kwargs['plane_resolution']
else:
unit_size = 1 / encoder_kwargs['grid_resolution']
local_mapping = map2local(unit_size)
self.encoder = encoder_dict[encoder](
dim=dim, c_dim=c_dim, map2local=local_mapping,
**encoder_kwargs
)
if self.predict_normal:
# decoder for normal prediction
self.decoder_normal = decoder_dict[decoder](
dim=dim, c_dim=c_dim, out_dim=out_dim,
**decoder_kwargs)
if self.predict_offset:
# decoder for offset prediction
self.decoder_offset = decoder_dict[decoder](
dim=dim, c_dim=c_dim, out_dim=out_dim_offset,
map2local=local_mapping,
**decoder_kwargs)
self.s_off = cfg['model']['s_offset']
def forward(self, p):
''' Performs a forward pass through the network.
Args:
p (tensor): input unoriented points
'''
time_dict = {}
mask = None
batch_size = p.size(0)
points = p.clone()
# encode the input point cloud to a feature volume
t0 = time.perf_counter()
c = self.encoder(p)
t1 = time.perf_counter()
if self.predict_offset:
offset = self.decoder_offset(p, c)
# more than one offset is predicted per-point
if self.num_offset > 1:
points = points.repeat(1, 1, self.num_offset).reshape(batch_size, -1, 3)
points = points + self.s_off * offset
else:
points = p
if self.predict_normal:
normals = self.decoder_normal(points, c)
t2 = time.perf_counter()
time_dict['encode'] = t1 - t0
time_dict['predict'] = t2 - t1
points = torch.clamp(points, 0.0, 0.99)
if self.cfg['model']['normal_normalize']:
normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8)
return points, normals

139
src/model_rgb.py Normal file
View file

@ -0,0 +1,139 @@
import numpy as np
import torch
import torch.nn as nn
from src.dpsr import DPSR
from src.model import PSR2Mesh, PSR2SurfacePoints
from src.network.net_rgb import RenderingNetwork
from src.utils import grid_interp
from pytorch3d.renderer import (
RasterizationSettings,
PerspectiveCameras,
MeshRenderer,
MeshRasterizer,
SoftSilhouetteShader)
from pytorch3d.structures import Meshes
def approx_psr_grad(psr_grid, res, normalize=True):
delta_x = delta_y = delta_z = 1/res
psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()
grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x
grad_y = (psr_pad[:, 2:, :] - psr_pad[:, :-2, :]) / 2 / delta_y
grad_z = (psr_pad[:, :, 2:] - psr_pad[:, :, :-2]) / 2 / delta_z
grad_x = grad_x[:, 1:-1, 1:-1]
grad_y = grad_y[1:-1, :, 1:-1]
grad_z = grad_z[1:-1, 1:-1, :]
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=3) # [res_x, res_y, res_z, 3]
if normalize:
psr_grad = psr_grad / (psr_grad.norm(dim=3, keepdim=True) + 1e-12)
return psr_grad
class SAP2Image(nn.Module):
def __init__(self, cfg, img_size):
super().__init__()
self.psr2sur = PSR2SurfacePoints.apply
self.psr2mesh = PSR2Mesh.apply
# initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
cfg['model']['grid_res'],
cfg['model']['grid_res']),
sig=cfg['model']['psr_sigma'])
self.cfg = cfg
if cfg['train']['l_weight']['rgb'] != 0.:
self.rendering_network = RenderingNetwork(**cfg['model']['renderer'])
if cfg['train']['l_weight']['mask'] != 0.:
# initialize rasterizer
sigma = 1e-4
raster_settings_soft = RasterizationSettings(
image_size=img_size,
blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
faces_per_pixel=150,
perspective_correct=False
)
# initialize silhouette renderer
self.mesh_rasterizer = MeshRenderer(
rasterizer=MeshRasterizer(
raster_settings=raster_settings_soft
),
shader=SoftSilhouetteShader()
)
self.cfg = cfg
self.img_size = img_size
def forward(self, inputs, data):
points, normals = inputs[...,:3], inputs[...,3:]
points = torch.sigmoid(points)
normals = normals / normals.norm(dim=-1, keepdim=True)
# DPSR to get grid
psr_grid = self.dpsr(points, normals).unsqueeze(1)
psr_grid = torch.tanh(psr_grid)
return self.render_img(psr_grid, data)
def render_img(self, psr_grid, data):
n_views = len(data['masks'])
n_views_per_iter = self.cfg['data']['n_views_per_iter']
rgb_render_mode = self.cfg['model']['renderer']['mode']
uv = data['uv']
idx = np.random.randint(0, n_views, n_views_per_iter)
pose = [data['poses'][i] for i in idx]
rgb = data['rgbs'][idx]
mask_gt = data['masks'][idx]
ray = None
pred_rgb = None
pred_mask = None
visible_mask = None
if self.cfg['train']['l_weight']['rgb'] != 0.:
psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res'])
p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
fea_interp = None
if 'rays' in data.keys():
ray = data['rays'].squeeze()[idx][visible_mask]
pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp)
# silhouette loss
if self.cfg['train']['l_weight']['mask'] != 0.:
# build mesh
v, f, _ = self.psr2mesh(psr_grid)
v = v * 2. - 1 # within the range of [-1, 1]
# ! Fast but more GPU usage
mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])
if True:
#! PyTorch3D silhouette loss
# build pose
R = torch.cat([p.R for p in pose], dim=0)
T = torch.cat([p.T for p in pose], dim=0)
focal = torch.cat([p.focal_length for p in pose], dim=0)
pp = torch.cat([p.principal_point for p in pose], dim=0)
pose_cur = PerspectiveCameras(
focal_length=focal,
principal_point=pp,
R=R, T=T,
device=R.device)
pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]
else:
pred_mask = []
# ! Slow but less GPU usage
for i in range(n_views_per_iter):
#! PyTorch3D silhouette loss
pred_mask.append(self.mesh_rasterizer(mesh, cameras=pose[i])[..., 3])
pred_mask = torch.cat(pred_mask, dim=0)
output = {
'rgb': pred_rgb,
'rgb_gt': rgb,
'mask': pred_mask,
'mask_gt': mask_gt,
'vis_mask': visible_mask,
}
return output

8
src/network/__init__.py Normal file
View file

@ -0,0 +1,8 @@
from src.network import encoder, decoder
encoder_dict = {
'local_pool_pointnet': encoder.LocalPoolPointnet,
}
decoder_dict = {
'simple_local': decoder.LocalDecoder,
}

106
src/network/decoder.py Normal file
View file

@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \
normalize_coordinate
class LocalDecoder(nn.Module):
''' Decoder.
Instead of conditioning on global features, on plane/volume local features.
Args:
dim (int): input dimension
c_dim (int): dimension of latent conditioned code c
hidden_size (int): hidden size of Decoder network
n_blocks (int): number of ResnetBlockFC blocks
leaky (bool): whether to use leaky ReLUs
sample_mode (str): sampling feature strategy, bilinear|nearest
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
'''
def __init__(self, dim=3, c_dim=128, out_dim=3,
hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None):
super().__init__()
self.c_dim = c_dim
self.n_blocks = n_blocks
if c_dim != 0:
self.fc_c = nn.ModuleList([
nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
])
self.fc_p = nn.Linear(dim, hidden_size)
self.blocks = nn.ModuleList([
ResnetBlockFC(hidden_size) for i in range(n_blocks)
])
self.fc_out = nn.Linear(hidden_size, out_dim)
if not leaky:
self.actvn = F.relu
else:
self.actvn = lambda x: F.leaky_relu(x, 0.2)
self.sample_mode = sample_mode
self.padding = padding
self.map2local = map2local
self.out_dim = out_dim
def sample_plane_feature(self, p, c, plane='xz'):
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
xy = xy[:, :, None].float()
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
c = F.grid_sample(c, vgrid, padding_mode='border',
align_corners=True,
mode=self.sample_mode).squeeze(-1)
return c
def sample_grid_feature(self, p, c):
p_nor = normalize_3d_coordinate(p.clone())
p_nor = p_nor[:, :, None, None].float()
vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
# actually trilinear interpolation if mode = 'bilinear'
c = F.grid_sample(c, vgrid, padding_mode='border',
align_corners=True,
mode=self.sample_mode).squeeze(-1).squeeze(-1)
return c
def forward(self, p, c_plane, **kwargs):
batch_size = p.shape[0]
plane_type = list(c_plane.keys())
c = 0
if 'grid' in plane_type:
c += self.sample_grid_feature(p, c_plane['grid'])
if 'xz' in plane_type:
c += self.sample_plane_feature(p, c_plane['xz'], plane='xz')
if 'xy' in plane_type:
c += self.sample_plane_feature(p, c_plane['xy'], plane='xy')
if 'yz' in plane_type:
c += self.sample_plane_feature(p, c_plane['yz'], plane='yz')
c = c.transpose(1, 2)
p = p.float()
if self.map2local:
p = self.map2local(p)
net = self.fc_p(p)
for i in range(self.n_blocks):
if self.c_dim != 0:
net = net + self.fc_c[i](c)
net = self.blocks[i](net)
out = self.fc_out(self.actvn(net))
if self.out_dim > 3:
out = out.reshape(batch_size, -1, 3)
return out
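
# Standalone sketch (illustrative sizes): decode a per-point 3D output from a single
# 'xz' feature plane, which is the interface Encode2Points uses for its normal and
# offset decoders.
if __name__ == '__main__':
    decoder = LocalDecoder(dim=3, c_dim=32, out_dim=3, hidden_size=64, n_blocks=3)
    p = torch.rand(2, 512, 3)                     # query points in [0, 1)
    c_plane = {'xz': torch.randn(2, 32, 64, 64)}  # plane features from the encoder
    out = decoder(p, c_plane)
    print(out.shape)  # torch.Size([2, 512, 3])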

181
src/network/encoder.py Normal file
View file

@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import numpy as np
from src.network.unet3d import UNet3D
from src.network.unet import UNet
from torch_scatter import scatter_mean, scatter_max
from src.network.utils import get_embedder, normalize_3d_coordinate,\
coordinate2index, ResnetBlockFC, normalize_coordinate
class LocalPoolPointnet(nn.Module):
''' PointNet-based encoder network with ResNet blocks for each point.
Number of input points are fixed.
Args:
c_dim (int): dimension of latent code c
dim (int): input points dimension
hidden_dim (int): hidden dimension of the network
scatter_type (str): feature aggregation when doing local pooling
unet (bool): whether to use U-Net
unet_kwargs (str): U-Net parameters
unet3d (bool): whether to use 3D U-Net
unet3d_kwargs (str): 3D U-Net parameters
plane_resolution (int): defined resolution for plane feature
grid_resolution (int): defined resolution for grid feature
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
n_blocks (int): number of ResnetBlockFC blocks
map2local (function): map global coordinates to local ones
pos_encoding (int): frequency for the positional encoding
'''
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
map2local=None, pos_encoding=0):
super().__init__()
self.c_dim = c_dim
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
self.blocks = nn.ModuleList([
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
])
self.fc_c = nn.Linear(hidden_dim, c_dim)
self.actvn = nn.ReLU()
self.hidden_dim = hidden_dim
self.unet = None
if unet:
self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
self.unet3d = None
if unet3d:
self.unet3d = UNet3D(**unet3d_kwargs)
self.reso_plane = plane_resolution
self.reso_grid = grid_resolution
self.plane_type = plane_type
self.padding = padding
self.pe = None
if pos_encoding > 0:
embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim)
self.pe = embed_fn
self.fc_pos = nn.Linear(input_ch, 2*hidden_dim)
self.map2local = map2local
if scatter_type == 'max':
self.scatter = scatter_max
elif scatter_type == 'mean':
self.scatter = scatter_mean
else:
raise ValueError('incorrect scatter type')
def generate_plane_features(self, p, c, plane='xz'):
# acquire indices of features in plane
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
index = coordinate2index(xy, self.reso_plane)
# scatter plane features from points
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
c = c.permute(0, 2, 1) # B x 512 x T
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x c_dim x reso x reso)
# process the plane features with UNet
if self.unet is not None:
fea_plane = self.unet(fea_plane)
return fea_plane
def generate_grid_features(self, p, c):
p_nor = normalize_3d_coordinate(p.clone())
index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
# scatter grid features from points
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
c = c.permute(0, 2, 1)
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparse matrix (B x c_dim x reso x reso x reso)
if self.unet3d is not None:
fea_grid = self.unet3d(fea_grid)
return fea_grid
def pool_local(self, xy, index, c):
bs, fea_dim = c.size(0), c.size(2)
keys = xy.keys()
c_out = 0
for key in keys:
# scatter plane features from points
if key == 'grid':
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
else:
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
if self.scatter == scatter_max:
fea = fea[0]
# gather feature back to points
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
c_out += fea
return c_out.permute(0, 2, 1)
def forward(self, p, normalize=True):
batch_size, T, D = p.size()
# acquire the index for each point
coord = {}
index = {}
if 'xz' in self.plane_type:
coord['xz'] = normalize_coordinate(p.clone(), plane='xz')
index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
if 'xy' in self.plane_type:
coord['xy'] = normalize_coordinate(p.clone(), plane='xy')
index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
if 'yz' in self.plane_type:
coord['yz'] = normalize_coordinate(p.clone(), plane='yz')
index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
if 'grid' in self.plane_type:
if normalize:
coord['grid'] = normalize_3d_coordinate(p.clone())
else:
coord['grid'] = p.clone()[...,:3]
index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
if self.pe:
p = self.pe(p)
if self.map2local:
pp = self.map2local(p)
net = self.fc_pos(pp)
else:
net = self.fc_pos(p)
# net = self.fc_pos(p)
net = self.blocks[0](net)
for block in self.blocks[1:]:
pooled = self.pool_local(coord, index, net)
net = torch.cat([net, pooled], dim=2)
net = block(net)
c = self.fc_c(net)
fea = {}
if 'grid' in self.plane_type:
fea['grid'] = self.generate_grid_features(p, c)
if 'xz' in self.plane_type:
fea['xz'] = self.generate_plane_features(p, c, plane='xz')
if 'xy' in self.plane_type:
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
if 'yz' in self.plane_type:
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
return fea
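
# Standalone sketch (illustrative sizes): encode an unoriented point cloud into a
# single 'xz' feature plane without a U-Net; pairs with the LocalDecoder sketch in
# decoder.py.
if __name__ == '__main__':
    encoder = LocalPoolPointnet(c_dim=32, dim=3, hidden_dim=32,
                                plane_resolution=64, plane_type=['xz'],
                                scatter_type='mean')
    p = torch.rand(2, 1024, 3)
    fea = encoder(p)
    print(fea['xz'].shape)  # torch.Size([2, 32, 64, 64])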

234
src/network/net_rgb.py Normal file
View file

@ -0,0 +1,234 @@
# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py)
import torch
import torch.nn as nn
import numpy as np
from src.network.utils import get_embedder
from pdb import set_trace as st
class RenderingNetwork(nn.Module):
def __init__(
self,
fea_size=0,
mode='naive',
d_out=3,
dims=[512, 512, 512, 512],
weight_norm=True,
pe_freq_view=0 # for positional encoding
):
super().__init__()
self.mode = mode
if mode == 'naive':
d_in = 3
elif mode == 'no_feature':
d_in = 3 + 3 + 3
fea_size = 0
elif mode == 'full':
d_in = 3 + 3 + 3
else:
d_in = 3 + 3
dims = [d_in + fea_size] + dims + [d_out]
self.embedview_fn = None
if pe_freq_view > 0:
embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3)
self.embedview_fn = embedview_fn
dims[0] += (input_ch - 3)
self.num_layers = len(dims)
for l in range(0, self.num_layers - 1):
out_dim = dims[l + 1]
lin = nn.Linear(dims[l], out_dim)
if weight_norm:
lin = nn.utils.weight_norm(lin)
setattr(self, "lin" + str(l), lin)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def forward(self, points, normals=None, view_dirs=None, feature_vectors=None):
if self.embedview_fn is not None:
view_dirs = self.embedview_fn(view_dirs)
# points = self.embedview_fn(points)
if (self.mode == 'full') & (feature_vectors is not None):
rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
elif (self.mode == 'no_feature') | ((self.mode == 'full') & (feature_vectors is None)):
rendering_input = torch.cat([points, view_dirs, normals], dim=-1)
elif self.mode == 'no_view_dir':
rendering_input = torch.cat([points, normals], dim=-1)
elif self.mode == 'no_normal':
rendering_input = torch.cat([points, view_dirs], dim=-1)
else:
rendering_input = points
x = rendering_input
for l in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(l))
x = lin(x)
if l < self.num_layers - 2:
x = self.relu(x)
x = self.tanh(x)
return x
class NeRFRenderingNetwork(nn.Module):
def __init__(
self,
feature_vector_size=0,
mode='naive',
d_in=3,
d_out=3,
dims=[512, 512, 512, 256],
weight_norm=True,
multires=0, # positional encoding of points
multires_view=0 # positional encoding of view
):
super().__init__()
self.mode = mode
dims = [d_in + feature_vector_size] + dims
self.embed_fn = None
if multires > 0:
embed_fn, input_ch = get_embedder(multires, d_in=d_in)
self.embed_fn = embed_fn
dims[0] += (input_ch - 3)
self.num_layers = len(dims)
self.pts_net = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.num_layers - 1)])
self.embedview_fn = None
if multires_view > 0:
embedview_fn, view_ch = get_embedder(multires_view, d_in=3)
self.embedview_fn = embedview_fn
# dims[0] += (input_ch - 3)
if mode == 'full':
self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)])
self.rgb_net = nn.Linear(128, 3)
else:
self.rgb_net = nn.Linear(dims[-1], 3)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def forward(self, points, normals=None, view_dirs=None, feature_vectors=None):
if self.embed_fn is not None:
points = self.embed_fn(points)
if self.embedview_fn is not None:
view_dirs = self.embedview_fn(view_dirs)
x = points
for net in self.pts_net:
x = net(x)
x = self.relu(x)
if self.mode=='full':
x = torch.cat([x, view_dirs], -1)
for net in self.view_net:
x = net(x)
x = self.relu(x)
x = self.rgb_net(x)
x = self.tanh(x)
return x
class ImplicitNetwork(nn.Module):
def __init__(
self,
d_in,
d_out,
dims,
geometric_init=True,
feature_vector_size=0,
bias=1.0,
skip_in=(),
weight_norm=True,
multires=0
):
super().__init__()
dims = [d_in] + dims + [d_out + feature_vector_size]
self.embed_fn = None
if multires > 0:
embed_fn, input_ch = get_embedder(multires)
self.embed_fn = embed_fn
dims[0] = input_ch
self.num_layers = len(dims)
self.skip_in = skip_in
for l in range(0, self.num_layers - 1):
if l + 1 in self.skip_in:
out_dim = dims[l + 1] - dims[0]
else:
out_dim = dims[l + 1]
lin = nn.Linear(dims[l], out_dim)
if geometric_init:
if l == self.num_layers - 2:
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
torch.nn.init.constant_(lin.bias, -bias)
elif multires > 0 and l == 0:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
elif multires > 0 and l in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
else:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
if weight_norm:
lin = nn.utils.weight_norm(lin)
setattr(self, "lin" + str(l), lin)
self.softplus = nn.Softplus(beta=100)
def forward(self, input, compute_grad=False):
if self.embed_fn is not None:
input = self.embed_fn(input)
x = input
for l in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(l))
if l in self.skip_in:
x = torch.cat([x, input], 1) / np.sqrt(2)
x = lin(x)
if l < self.num_layers - 2:
x = self.softplus(x)
return x
def gradient(self, x):
x.requires_grad_(True)
y = self.forward(x)[:,:1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
return gradients.unsqueeze(1)
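
# Standalone sketch (illustrative sizes): predict per-point RGB from surface points,
# normals and viewing directions, the way SAP2Image calls the rendering network when
# the rgb loss is active.
if __name__ == '__main__':
    net = RenderingNetwork(mode='no_feature', dims=[128, 128])
    pts = torch.rand(1000, 3) * 2. - 1.
    nrm = torch.randn(1000, 3)
    nrm = nrm / nrm.norm(dim=-1, keepdim=True)
    dirs = torch.randn(1000, 3)
    dirs = dirs / dirs.norm(dim=-1, keepdim=True)
    rgb = net(pts, normals=nrm, view_dirs=dirs)
    print(rgb.shape)  # torch.Size([1000, 3]), values in (-1, 1) from the tanh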

256
src/network/unet.py Normal file
View file

@ -0,0 +1,256 @@
'''
Codes are from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import numpy as np
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def upconv2x2(in_channels, out_channels, mode='transpose'):
if mode == 'transpose':
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2,
stride=2)
else:
# out_channels is always going to be the same
# as in_channels
return nn.Sequential(
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
class DownConv(nn.Module):
"""
A helper Module that performs 2 convolutions and 1 MaxPool.
A ReLU activation follows each convolution.
"""
def __init__(self, in_channels, out_channels, pooling=True):
super(DownConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.pooling = pooling
self.conv1 = conv3x3(self.in_channels, self.out_channels)
self.conv2 = conv3x3(self.out_channels, self.out_channels)
if self.pooling:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
before_pool = x
if self.pooling:
x = self.pool(x)
return x, before_pool
class UpConv(nn.Module):
"""
A helper Module that performs 2 convolutions and 1 UpConvolution.
A ReLU activation follows each convolution.
"""
def __init__(self, in_channels, out_channels,
merge_mode='concat', up_mode='transpose'):
super(UpConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.merge_mode = merge_mode
self.up_mode = up_mode
self.upconv = upconv2x2(self.in_channels, self.out_channels,
mode=self.up_mode)
if self.merge_mode == 'concat':
self.conv1 = conv3x3(
2*self.out_channels, self.out_channels)
else:
# num of input channels to conv2 is same
self.conv1 = conv3x3(self.out_channels, self.out_channels)
self.conv2 = conv3x3(self.out_channels, self.out_channels)
def forward(self, from_down, from_up):
""" Forward pass
Arguments:
from_down: tensor from the encoder pathway
from_up: upconv'd tensor from the decoder pathway
"""
from_up = self.upconv(from_up)
if self.merge_mode == 'concat':
x = torch.cat((from_up, from_down), 1)
else:
x = from_up + from_down
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
return x
class UNet(nn.Module):
""" `UNet` class is based on https://arxiv.org/abs/1505.04597
The U-Net is a convolutional encoder-decoder neural network.
Contextual spatial information (from the decoding,
expansive pathway) about an input tensor is merged with
information representing the localization of details
(from the encoding, compressive pathway).
Modifications to the original paper:
(1) padding is used in 3x3 convolutions to prevent loss
of border pixels
(2) merging outputs does not require cropping due to (1)
(3) residual connections can be used by specifying
UNet(merge_mode='add')
(4) if non-parametric upsampling is used in the decoder
pathway (specified by upmode='upsample'), then an
additional 1x1 2d convolution occurs after upsampling
to reduce channel dimensionality by a factor of 2.
This channel halving happens with the convolution in
the transpose convolution (specified by upmode='transpose')
"""
def __init__(self, num_classes, in_channels=3, depth=5,
start_filts=64, up_mode='transpose',
merge_mode='concat', **kwargs):
"""
Arguments:
in_channels: int, number of channels in the input tensor.
Default is 3 for RGB images.
depth: int, number of MaxPools in the U-Net.
start_filts: int, number of convolutional filters for the
first conv.
up_mode: string, type of upconvolution. Choices: 'transpose'
for transpose convolution or 'upsample' for nearest neighbour
upsampling.
"""
super(UNet, self).__init__()
if up_mode in ('transpose', 'upsample'):
self.up_mode = up_mode
else:
raise ValueError("\"{}\" is not a valid mode for "
"upsampling. Only \"transpose\" and "
"\"upsample\" are allowed.".format(up_mode))
if merge_mode in ('concat', 'add'):
self.merge_mode = merge_mode
else:
raise ValueError("\"{}\" is not a valid mode for "
"merging up and down paths. "
"Only \"concat\" and "
"\"add\" are allowed.".format(merge_mode))
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
if self.up_mode == 'upsample' and self.merge_mode == 'add':
raise ValueError("up_mode \"upsample\" is incompatible "
"with merge_mode \"add\" at the moment "
"because it doesn't make sense to use "
"nearest neighbour to reduce "
"depth channels (by half).")
self.num_classes = num_classes
self.in_channels = in_channels
self.start_filts = start_filts
self.depth = depth
self.down_convs = []
self.up_convs = []
# create the encoder pathway and add to a list
for i in range(depth):
ins = self.in_channels if i == 0 else outs
outs = self.start_filts*(2**i)
pooling = True if i < depth-1 else False
down_conv = DownConv(ins, outs, pooling=pooling)
self.down_convs.append(down_conv)
# create the decoder pathway and add to a list
# - careful! decoding only requires depth-1 blocks
for i in range(depth-1):
ins = outs
outs = ins // 2
up_conv = UpConv(ins, outs, up_mode=up_mode,
merge_mode=merge_mode)
self.up_convs.append(up_conv)
# add the list of modules to current module
self.down_convs = nn.ModuleList(self.down_convs)
self.up_convs = nn.ModuleList(self.up_convs)
self.conv_final = conv1x1(outs, self.num_classes)
self.reset_params()
@staticmethod
def weight_init(m):
if isinstance(m, nn.Conv2d):
init.xavier_normal_(m.weight)
init.constant_(m.bias, 0)
def reset_params(self):
for i, m in enumerate(self.modules()):
self.weight_init(m)
def forward(self, x):
encoder_outs = []
# encoder pathway, save outputs for merging
for i, module in enumerate(self.down_convs):
x, before_pool = module(x)
encoder_outs.append(before_pool)
for i, module in enumerate(self.up_convs):
before_pool = encoder_outs[-(i+2)]
x = module(before_pool, x)
# No softmax is used. This means you need to use
# nn.CrossEntropyLoss in your training script,
# as this module includes a softmax already.
x = self.conv_final(x)
return x
if __name__ == "__main__":
"""
testing
"""
model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
print(model)
print(sum(p.numel() for p in model.parameters()))
reso = 176
x = np.zeros((1, 1, reso, reso))
x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
x = torch.FloatTensor(x)
out = model(x)
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
# loss = torch.sum(out)
# loss.backward()

559
src/network/unet3d.py Normal file
View file

@ -0,0 +1,559 @@
'''
Code from the 3D UNet implementation:
https://github.com/wolny/pytorch-3dunet/
'''
import importlib
import torch
import torch.nn as nn
from torch.nn import functional as F
from functools import partial
from src.network.utils import get_embedder
def number_of_features_per_level(init_channel_number, num_levels):
return [init_channel_number * 2 ** k for k in range(num_levels)]
def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
"""
Create a list of modules that together constitute a single conv layer with non-linearity
and optional batchnorm/groupnorm.
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
order (string): order of things, e.g.
'cr' -> conv + ReLU
'gcr' -> groupnorm + conv + ReLU
'cl' -> conv + LeakyReLU
'ce' -> conv + ELU
'bcr' -> batchnorm + conv + ReLU
num_groups (int): number of groups for the GroupNorm
padding (int): add zero-padding to the input
Return:
list of tuple (name, module)
"""
assert 'c' in order, "Conv layer MUST be present"
assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
modules = []
for i, char in enumerate(order):
if char == 'r':
modules.append(('ReLU', nn.ReLU(inplace=True)))
elif char == 'l':
modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
elif char == 'e':
modules.append(('ELU', nn.ELU(inplace=True)))
elif char == 'c':
# add learnable bias only in the absence of batchnorm/groupnorm
bias = not ('g' in order or 'b' in order)
modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
elif char == 'g':
is_before_conv = i < order.index('c')
if is_before_conv:
num_channels = in_channels
else:
num_channels = out_channels
# use only one group if the given number of groups is greater than the number of channels
if num_channels < num_groups:
num_groups = 1
assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
elif char == 'b':
is_before_conv = i < order.index('c')
if is_before_conv:
modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
else:
modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
else:
raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
return modules
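# A minimal usage sketch of the 'order' convention above (kept as comments so the
# module has no import-time side effects); in_channels=32, out_channels=64 and
# num_groups=8 are assumed example values:
#
#   modules = create_conv(32, 64, kernel_size=3, order='gcr', num_groups=8)
#   # -> [('groupnorm', nn.GroupNorm(8, 32)),                     # 'g' before 'c' normalizes the input channels
#   #     ('conv', nn.Conv3d(32, 64, 3, padding=1, bias=False)),  # bias dropped because a norm layer is present
#   #     ('ReLU', nn.ReLU(inplace=True))]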
class SingleConv(nn.Sequential):
"""
Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
of operations can be specified via the `order` parameter
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
kernel_size (int): size of the convolving kernel
order (string): determines the order of layers, e.g.
'cr' -> conv + ReLU
'crg' -> conv + ReLU + groupnorm
'cl' -> conv + LeakyReLU
'ce' -> conv + ELU
num_groups (int): number of groups for the GroupNorm
"""
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
super(SingleConv, self).__init__()
for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
self.add_module(name, module)
class DoubleConv(nn.Sequential):
"""
A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
We use (Conv3d+ReLU+GroupNorm3d) by default.
This can be changed however by providing the 'order' argument, e.g. in order
to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
Use padded convolutions to make sure that the output (H_out, W_out) is the same
as (H_in, W_in), so that you don't have to crop in the decoder path.
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
kernel_size (int): size of the convolving kernel
order (string): determines the order of layers, e.g.
'cr' -> conv + ReLU
'crg' -> conv + ReLU + groupnorm
'cl' -> conv + LeakyReLU
'ce' -> conv + ELU
num_groups (int): number of groups for the GroupNorm
"""
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
super(DoubleConv, self).__init__()
if encoder:
# we're in the encoder path
conv1_in_channels = in_channels
conv1_out_channels = out_channels // 2
if conv1_out_channels < in_channels:
conv1_out_channels = in_channels
conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
else:
# we're in the decoder path, decrease the number of channels in the 1st convolution
conv1_in_channels, conv1_out_channels = in_channels, out_channels
conv2_in_channels, conv2_out_channels = out_channels, out_channels
# conv1
self.add_module('SingleConv1',
SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
# conv2
self.add_module('SingleConv2',
SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
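# Channel bookkeeping sketch for DoubleConv (assumed example values, kept as comments):
#   DoubleConv(64, 128, encoder=True)        # conv1: 64 -> 64 (out_channels // 2, floored at in_channels)
#                                            # conv2: 64 -> 128
#   DoubleConv(128 + 64, 64, encoder=False)  # decoder input is the concatenation of skip and upsampled features
#                                            # conv1: 192 -> 64, conv2: 64 -> 64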
class ExtResNetBlock(nn.Module):
"""
Basic UNet block consisting of a SingleConv followed by the residual block.
The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
of output channels is compatible with the residual block that follows.
This block can be used instead of standard DoubleConv in the Encoder module.
Motivated by: https://arxiv.org/pdf/1706.00120.pdf
Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
"""
def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
super(ExtResNetBlock, self).__init__()
# first convolution
self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
# residual block
self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
# remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
n_order = order
for c in 'rel':
n_order = n_order.replace(c, '')
self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
num_groups=num_groups)
# create non-linearity separately
if 'l' in order:
self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
elif 'e' in order:
self.non_linearity = nn.ELU(inplace=True)
else:
self.non_linearity = nn.ReLU(inplace=True)
def forward(self, x):
# apply first convolution and save the output as a residual
out = self.conv1(x)
residual = out
# residual block
out = self.conv2(out)
out = self.conv3(out)
out += residual
out = self.non_linearity(out)
return out
class Encoder(nn.Module):
"""
A single module from the encoder path consisting of an optional max
pooling layer followed by a DoubleConv module. The MaxPool kernel_size may be
specified differently from the standard (2, 2, 2), e.g. if the volumetric data
is anisotropic; make sure to use a complementary scale_factor in the decoder path.
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
conv_kernel_size (int): size of the convolving kernel
apply_pooling (bool): if True use MaxPool3d before DoubleConv
pool_kernel_size (tuple): the size of the window to take a max over
pool_type (str): pooling layer: 'max' or 'avg'
basic_module(nn.Module): either ResNetBlock or DoubleConv
conv_layer_order (string): determines the order of layers
in `DoubleConv` module. See `DoubleConv` for more info.
num_groups (int): number of groups for the GroupNorm
"""
def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg',
num_groups=8):
super(Encoder, self).__init__()
assert pool_type in ['max', 'avg']
if apply_pooling:
if pool_type == 'max':
self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
else:
self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
else:
self.pooling = None
self.basic_module = basic_module(in_channels, out_channels,
encoder=True,
kernel_size=conv_kernel_size,
order=conv_layer_order,
num_groups=num_groups)
def forward(self, x):
if self.pooling is not None:
x = self.pooling(x)
x = self.basic_module(x)
return x
class Decoder(nn.Module):
"""
A single module for decoder path consisting of the upsampling layer
(either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
kernel_size (int): size of the convolving kernel
scale_factor (tuple): used as the multiplier for the image H/W/D in
case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
from the corresponding encoder
basic_module(nn.Module): either ResNetBlock or DoubleConv
conv_layer_order (string): determines the order of layers
in `DoubleConv` module. See `DoubleConv` for more info.
num_groups (int): number of groups for the GroupNorm
"""
def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
conv_layer_order='crg', num_groups=8, mode='nearest'):
super(Decoder, self).__init__()
if basic_module == DoubleConv:
# if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
# concat joining
self.joining = partial(self._joining, concat=True)
else:
# if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
# sum joining
self.joining = partial(self._joining, concat=False)
# adapt the number of in_channels for the ExtResNetBlock
in_channels = out_channels
self.basic_module = basic_module(in_channels, out_channels,
encoder=False,
kernel_size=kernel_size,
order=conv_layer_order,
num_groups=num_groups)
def forward(self, encoder_features, x):
x = self.upsampling(encoder_features=encoder_features, x=x)
x = self.joining(encoder_features, x)
x = self.basic_module(x)
return x
@staticmethod
def _joining(encoder_features, x, concat):
if concat:
return torch.cat((encoder_features, x), dim=1)
else:
return encoder_features + x
class Upsampling(nn.Module):
"""
Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
Args:
transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
in_channels (int): number of input channels for transposed conv
out_channels (int): number of output channels for transpose conv
kernel_size (int or tuple): size of the convolving kernel
scale_factor (int or tuple): stride of the convolution
mode (str): algorithm used for upsampling:
'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
"""
def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3,
scale_factor=(2, 2, 2), mode='nearest'):
super(Upsampling, self).__init__()
if transposed_conv:
# make sure that the output size reverses the MaxPool3d from the corresponding encoder
# (D_out = (D_in - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0])
self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
padding=1)
else:
self.upsample = partial(self._interpolate, mode=mode)
def forward(self, encoder_features, x):
output_size = encoder_features.size()[2:]
return self.upsample(x, output_size)
@staticmethod
def _interpolate(x, size, mode):
return F.interpolate(x, size=size, mode=mode)
class FinalConv(nn.Sequential):
"""
A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and a final 1x1 convolution
which reduces the number of channels to 'out_channels'.
We use (Conv3d+ReLU+GroupNorm3d) by default.
This can be changed however by providing the 'order' argument, e.g. in order
to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
kernel_size (int): size of the convolving kernel
order (string): determines the order of layers, e.g.
'cr' -> conv + ReLU
'crg' -> conv + ReLU + groupnorm
num_groups (int): number of groups for the GroupNorm
"""
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
super(FinalConv, self).__init__()
# conv1
self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
# in the last layer a 1×1 convolution reduces the number of output channels to out_channels
final_conv = nn.Conv3d(in_channels, out_channels, 1)
self.add_module('final_conv', final_conv)
class Abstract3DUNet(nn.Module):
"""
Base class for standard and residual UNet.
Args:
in_channels (int): number of input channels
out_channels (int): number of output segmentation masks;
Note that the out_channels might correspond to either
different semantic classes or to different binary segmentation masks.
It's up to the user of the class to interpret the out_channels and
use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
or BCEWithLogitsLoss (two-class) respectively)
f_maps (int, tuple): if int: number of feature maps in the first conv layer of the encoder (default: 64),
with the number of feature maps at level k given by f_maps * 2^k, k=0,...,num_levels-1;
if tuple: number of feature maps at each level
final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used
to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model.
basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....)
layer_order (string): determines the order of layers
in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
See `SingleConv` for more info
num_groups (int): number of groups for the GroupNorm
num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied
after the final convolution; if False (regression problem) the normalization layer is skipped at the end
testing (bool): if True (testing mode) the `final_activation` (if present, i.e. `is_segmentation=true`)
will be applied as the last operation during the forward pass; if False the model is in training mode
and the `final_activation` (even if present) won't be applied; default: False
"""
def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs):
super(Abstract3DUNet, self).__init__()
self.testing = testing
if isinstance(f_maps, int):
f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
# create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
self.embed_fn = None
if pe_freq > 0:
embed_fn, input_ch = get_embedder(pe_freq, d_in=in_channels)
self.embed_fn = embed_fn
in_channels = input_ch
encoders = []
for i, out_feature_num in enumerate(f_maps):
if i == 0:
encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module,
conv_layer_order=layer_order, num_groups=num_groups)
else:
# TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
# currently pools with a constant kernel: (2, 2, 2)
encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module,
conv_layer_order=layer_order, num_groups=num_groups)
encoders.append(encoder)
self.encoders = nn.ModuleList(encoders)
# create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
decoders = []
reversed_f_maps = list(reversed(f_maps))
for i in range(len(reversed_f_maps) - 1):
if basic_module == DoubleConv:
in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
else:
in_feature_num = reversed_f_maps[i]
out_feature_num = reversed_f_maps[i + 1]
# TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
# currently strides with a constant stride: (2, 2, 2)
decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module,
conv_layer_order=layer_order, num_groups=num_groups)
decoders.append(decoder)
self.decoders = nn.ModuleList(decoders)
# in the last layer a 1×1 convolution reduces the number of output
# channels to the number of labels
self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
if is_segmentation:
# semantic segmentation problem
if final_sigmoid:
self.final_activation = nn.Sigmoid()
else:
self.final_activation = nn.Softmax(dim=1)
else:
# regression problem
self.final_activation = None
def forward(self, x):
if self.embed_fn is not None:
x = self.embed_fn(x.permute(0, 2, 3, 4, 1))
x = x.permute(0, 4, 1, 2, 3)
# encoder part
encoders_features = []
for encoder in self.encoders:
x = encoder(x)
# reverse the encoder outputs to be aligned with the decoder
encoders_features.insert(0, x)
# remove the last encoder's output from the list
# !!remember: it's the 1st in the list
encoders_features = encoders_features[1:]
# decoder part
for decoder, encoder_features in zip(self.decoders, encoders_features):
# pass the output from the corresponding encoder and the output
# of the previous decoder
x = decoder(encoder_features, x)
x = self.final_conv(x)
# apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs
# logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
if self.testing and self.final_activation is not None:
x = self.final_activation(x)
return x
class UNet3D(Abstract3DUNet):
"""
3DUnet model from
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
<https://arxiv.org/pdf/1606.06650.pdf>`.
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
"""
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4, is_segmentation=True, **kwargs):
super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid,
basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order,
num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation,
**kwargs)
class ResidualUNet3D(Abstract3DUNet):
"""
Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
Uses ExtResNetBlock as a basic building block, summation joining instead
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
"""
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=5, is_segmentation=True, **kwargs):
super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels,
final_sigmoid=final_sigmoid,
basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order,
num_groups=num_groups, num_levels=num_levels,
is_segmentation=is_segmentation,
**kwargs)
def get_model(config):
def _model_class(class_name):
m = importlib.import_module('pytorch3dunet.unet3d.model')
clazz = getattr(m, class_name)
return clazz
assert 'model' in config, 'Could not find model configuration'
model_config = config['model']
model_class = _model_class(model_config['name'])
return model_class(**model_config)
if __name__ == "__main__":
"""
testing
"""
in_channels = 1
out_channels = 1
f_maps = 32
num_levels = 2
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr')
print(model)
print('number of parameters: ', sum(p.numel() for p in model.parameters()))
reso = 18
import numpy as np
import torch
x = np.zeros((1, 1, reso, reso, reso))
x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan
x = torch.FloatTensor(x)
out = model(x)
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso)))

167
src/network/utils.py Normal file
View file

@ -0,0 +1,167 @@
""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
import torch
import torch.nn as nn
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn,
freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, d_in=3):
embed_kwargs = {
'include_input': True,
'input_dims': d_in,
'max_freq_log2': multires-1,
'num_freqs': multires,
'log_sampling': True,
'periodic_fns': [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
def embed(x, eo=embedder_obj): return eo.embed(x)
return embed, embedder_obj.out_dim
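# Minimal usage sketch (multires=4 and 3D inputs are assumed example values, kept as comments):
#   embed_fn, out_dim = get_embedder(4, d_in=3)
#   # out_dim = 3 * (1 + 2 * 4) = 27: identity plus (sin, cos) for the 4 frequency bands [1, 2, 4, 8]
#   x = torch.rand(1024, 3)
#   feats = embed_fn(x)   # feats.shape == (1024, 27)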
def normalize_coordinate(p, plane='xz'):
''' Normalize coordinate to [0, 1] for unit cube experiments
Args:
p (tensor): point
plane (str): plane feature type, ['xz', 'xy', 'yz']
'''
if plane == 'xz':
xy = p[:, :, [0, 2]]
elif plane =='xy':
xy = p[:, :, [0, 1]]
else:
xy = p[:, :, [1, 2]]
xy_new = xy
# if there are outliers out of the range
if xy_new.max() >= 1:
xy_new[xy_new >= 1] = 1 - 10e-6
if xy_new.min() < 0:
xy_new[xy_new < 0] = 0.0
return xy_new
def normalize_3d_coordinate(p):
''' Normalize coordinate to [0, 1] for unit cube experiments.
'''
if p.max() >= 1:
p[p >= 1] = 1 - 10e-6
if p.min() < 0:
p[p < 0] = 0.0
return p
def coordinate2index(x, reso, coord_type='2d'):
''' Convert normalized coordinates in [0, 1] to flat indices on a 2D plane or 3D grid.
Args:
x (tensor): coordinate
reso (int): defined resolution
coord_type (str): coordinate type
'''
x = (x * reso).long()
if coord_type == '2d': # plane
index = x[:, :, 0] + reso * x[:, :, 1]
elif coord_type == '3d': # grid
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
index = index[:, None, :]
return index
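# Small worked example (assumed values, kept as comments):
#   x = torch.tensor([[[0.25, 0.75]]])              # one normalized 2D point, shape (B=1, N=1, 2)
#   coordinate2index(x, reso=4, coord_type='2d')
#   # (x * reso).long() -> (1, 3); flat index = 1 + 4 * 3 = 13, returned with shape (1, 1, 1)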
class map2local(object):
''' Map global coordinates to local coordinates within each voxel of size s
Args:
s (float): the defined voxel size
pos_encoding (str): method for the positional encoding, linear|sin_cos
'''
def __init__(self, s, pos_encoding='linear'):
super().__init__()
self.s = s
# self.pe = positional_encoding(basis_function=pos_encoding, local=True)
def __call__(self, p):
# p = torch.remainder(p, self.s) / self.s # always positive
p = (p % self.s) / self.s
p[p < 0] = 0.0
# p = torch.fmod(p, self.s) / self.s # same sign as input p!
# p = self.pe(p)
return p
# Resnet Blocks
class ResnetBlockFC(nn.Module):
''' Fully connected ResNet Block class.
Args:
size_in (int): input dimension
size_out (int): output dimension
size_h (int): hidden dimension
'''
def __init__(self, size_in, size_out=None, size_h=None, siren=False):
super().__init__()
# Attributes
if size_out is None:
size_out = size_in
if size_h is None:
size_h = min(size_in, size_out)
self.size_in = size_in
self.size_h = size_h
self.size_out = size_out
# Submodules
self.fc_0 = nn.Linear(size_in, size_h)
self.fc_1 = nn.Linear(size_h, size_out)
self.actvn = nn.ReLU()
if size_in == size_out:
self.shortcut = None
else:
self.shortcut = nn.Linear(size_in, size_out, bias=False)
# Initialization
nn.init.zeros_(self.fc_1.weight)
def forward(self, x):
net = self.fc_0(self.actvn(x))
dx = self.fc_1(self.actvn(net))
if self.shortcut is not None:
x_s = self.shortcut(x)
else:
x_s = x
return x_s + dx

349
src/optimization.py Normal file
View file

@ -0,0 +1,349 @@
import time, os
import numpy as np
import torch
from torch.nn import functional as F
import trimesh
from src.dpsr import DPSR
from src.model import PSR2Mesh
from src.utils import grid_interp, verts_on_largest_mesh,\
export_pointcloud, mc_from_psr, GaussianSmoothing
from src.visualize import visualize_points_mesh, visualize_psr_grid, \
visualize_mesh_phong, render_rgb
from torchvision.utils import save_image
from torchvision.io import write_video
from pytorch3d.loss import chamfer_distance
import open3d as o3d
class Trainer(object):
'''
Args:
cfg : config file
optimizer : pytorch optimizer object
device : pytorch device
'''
def __init__(self, cfg, optimizer, device=None):
self.optimizer = optimizer
self.device = device
self.cfg = cfg
self.psr2mesh = PSR2Mesh.apply
self.data_type = cfg['data']['data_type']
# initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
cfg['model']['grid_res'],
cfg['model']['grid_res']),
sig=cfg['model']['psr_sigma'])
if torch.cuda.device_count() > 1:
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR
self.dpsr = self.dpsr.to(device)
def train_step(self, data, inputs, model, it):
''' Performs a training step.
Args:
data (dict) : data dictionary
inputs (torch.tensor) : input point clouds
model (nn.Module or None): a neural network or None
it (int) : the number of iterations
'''
self.optimizer.zero_grad()
loss, loss_each = self.compute_loss(inputs, data, model, it)
loss.backward()
self.optimizer.step()
return loss.item(), loss_each
def compute_loss(self, inputs, data, model, it=0):
''' Compute the loss.
Args:
data (dict) : data dictionary
inputs (torch.tensor) : input point clouds
model (nn.Module or None): a neural network or None
it (int) : the number of iterations
'''
device = self.device
res = self.cfg['model']['grid_res']
# source oriented point clouds to PSR grid
psr_grid, points, normals = self.pcl2psr(inputs)
# build mesh
v, f, n = self.psr2mesh(psr_grid)
# the output is in the range of [0, 1); we rescale it to the real range [0, 1].
# This is a hack for our DPSR solver
v = v * res / (res-1)
points = points * 2. - 1.
v = v * 2. - 1. # within the range of (-1, 1)
loss = 0
loss_each = {}
# compute loss
if self.data_type == 'point':
if self.cfg['train']['w_chamfer'] > 0:
loss_ = self.cfg['train']['w_chamfer'] * \
self.compute_3d_loss(v, data)
loss_each['chamfer'] = loss_
loss += loss_
elif self.data_type == 'img':
loss, loss_each = self.compute_2d_loss(inputs, data, model)
return loss, loss_each
def pcl2psr(self, inputs):
''' Convert an oriented point cloud to PSR indicator grid
Args:
inputs (torch.tensor): input oriented point clouds
'''
points, normals = inputs[...,:3], inputs[...,3:]
if self.cfg['model']['apply_sigmoid']:
points = torch.sigmoid(points)
if self.cfg['model']['normal_normalize']:
normals = normals / normals.norm(dim=-1, keepdim=True)
# DPSR to get grid
psr_grid = self.dpsr(points, normals).unsqueeze(1)
psr_grid = torch.tanh(psr_grid)
return psr_grid, points, normals
def compute_3d_loss(self, v, data):
''' Compute the loss for point clouds.
Args:
v (torch.tensor) : mesh vertices
data (dict) : data dictionary
'''
pts_gt = data.get('target_points')
idx = np.random.randint(pts_gt.shape[1], size=self.cfg['train']['n_sup_point'])
if self.cfg['train']['subsample_vertex']:
#chamfer distance only on random sampled vertices
idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point'])
loss, _ = chamfer_distance(v[:, idx], pts_gt)
else:
loss, _ = chamfer_distance(v, pts_gt)
return loss
def compute_2d_loss(self, inputs, data, model):
''' Compute the 2D losses.
Args:
inputs (torch.tensor) : input source point clouds
data (dict) : data dictionary
model (nn.Module or None): neural network or None
'''
losses = {"color":
{"weight": self.cfg['train']['l_weight']['rgb'],
"values": []
},
"silhouette":
{"weight": self.cfg['train']['l_weight']['mask'],
"values": []},
}
loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses}
# forward pass
out = model(inputs, data)
if out['rgb'] is not None:
rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'],
-1, 3)[out['vis_mask']]
loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt,
out['rgb']) / out['rgb'].shape[0]
if out['mask'] is not None:
loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean()
# weighted sum of the losses
loss = torch.tensor(0.0, device=self.device)
for k, l in loss_all.items():
loss += l * losses[k]["weight"]
losses[k]["values"].append(l)
return loss, loss_all
def point_resampling(self, inputs):
''' Resample points
Args:
inputs (torch.tensor): oriented point clouds
'''
psr_grid, points, normals = self.pcl2psr(inputs)
# shortcuts
n_grow = self.cfg['train']['n_grow_points']
# [hack] for points resampled from the mesh from marching cubes,
# we need to divide by s instead of (s-1), and the scale is correct.
verts, faces, _ = mc_from_psr(psr_grid, real_scale=False, zero_level=0)
# find the largest component
pts_mesh, faces_mesh = verts_on_largest_mesh(verts, faces)
# sample vertices only from the largest component, not from fragments
mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh)
pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True)
normals_i = mesh.face_normals[face_idx].astype('float32')
pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None]
n_mesh = torch.tensor(normals_i).to(self.device)[None]
points, normals = pts_mesh, n_mesh
print('{} total points are resampled'.format(points.shape[1]))
# update inputs
points = torch.log(points / (1 - points)) # inverse sigmoid
inputs = torch.cat([points, normals], dim=-1)
inputs.requires_grad = True
return inputs
def visualize(self, data, inputs, renderer, epoch, o3d_vis=None):
''' Visualization.
Args:
data (dict) : data dictionary
inputs (torch.tensor) : source point clouds
renderer (nn.Module or None): a neural network or None
epoch (int) : the number of iterations
o3d_vis (o3d.Visualizer) : open3d visualizer
'''
data_type = self.cfg['data']['data_type']
it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every']))
if (self.cfg['train']['exp_mesh']) \
| (self.cfg['train']['exp_pcl']) \
| (self.cfg['train']['o3d_show']):
psr_grid, points, normals = self.pcl2psr(inputs)
with torch.no_grad():
v, f, n = mc_from_psr(psr_grid, pytorchify=True,
zero_level=self.cfg['data']['zero_level'], real_scale=True)
v, f, n = v[None], f[None], n[None]
v = v * 2. - 1. # change to the range of [-1, 1]
color_v = None
if data_type == 'img':
if self.cfg['train']['vis_vert_color'] & \
(self.cfg['train']['l_weight']['rgb'] != 0.):
color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy()
color_v[color_v<0], color_v[color_v>1] = 0., 1.
vv = v.detach().squeeze().cpu().numpy()
ff = f.detach().squeeze().cpu().numpy()
points = points * 2 - 1
visualize_points_mesh(o3d_vis, points, normals,
vv, ff, self.cfg, it, epoch, color_v=color_v)
else:
v, f, n = inputs
if (data_type == 'img') & (self.cfg['train']['vis_rendering']):
pred_imgs = []
pred_masks = []
n_views = len(data['poses'])
# idx_list = trange(n_views)
idx_list = [13, 24, 27, 48]
#!
model = renderer.eval()
for idx in idx_list:
pose = data['poses'][idx]
rgb = data['rgbs'][idx]
mask_gt = data['masks'][idx]
img_size = rgb.shape[0] if rgb.shape[0]== rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
ray = None
if 'rays' in data.keys():
ray = data['rays'][idx]
if self.cfg['train']['l_weight']['rgb'] != 0.:
fea_grid = None
if model.unet3d is not None:
with torch.no_grad():
fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1)
if model.encoder is not None:
pp = torch.cat([(points+1)/2, normals], dim=-1)
fea_grid = model.encoder(pp,
normalize=False).permute(0, 2, 3, 4, 1)
pred, visible_mask = render_rgb(v, f, n, pose,
model.rendering_network.eval(),
img_size, ray=ray, fea_grid=fea_grid)
img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3])
img_pred[visible_mask] = pred.detach().cpu()
img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3)
img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1.
filename=os.path.join(self.cfg['train']['dir_rendering'],
'rendering_{}_{:d}.png'.format(it, idx))
save_image(img_pred.permute(2, 0, 1), filename)
pred_imgs.append(img_pred)
#! Mesh rendering using Phong shading model
filename=os.path.join(self.cfg['train']['dir_rendering'],
'mesh_{}_{:d}.png'.format(it, idx))
visualize_mesh_phong(v, f, n, pose, img_size, name=filename)
if len(pred_imgs) >= 1:
pred_imgs = torch.stack(pred_imgs, dim=0)
save_image(pred_imgs.permute(0, 3, 1, 2),
os.path.join(self.cfg['train']['dir_rendering'],
'{}.png'.format(it)), nrow=4)
if self.cfg['train']['save_video']:
write_video(os.path.join(self.cfg['train']['dir_rendering'],
'{}.mp4'.format(it)),
(pred_imgs*255.).type(torch.uint8), fps=24)
def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None):
''' Save meshes and point clouds.
Args:
inputs (torch.tensor) : source point clouds
epoch (int) : the number of iterations
center (numpy.array) : center of the shape
scale (numpy.array) : scale of the shape
'''
exp_pcl = self.cfg['train']['exp_pcl']
exp_mesh = self.cfg['train']['exp_mesh']
psr_grid, points, normals = self.pcl2psr(inputs)
if exp_pcl:
dir_pcl = self.cfg['train']['dir_pcl']
p = points.squeeze(0).detach().cpu().numpy()
p = p * 2 - 1
n = normals.squeeze(0).detach().cpu().numpy()
if scale is not None:
p *= scale
if center is not None:
p += center
export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n)
if exp_mesh:
dir_mesh = self.cfg['train']['dir_mesh']
with torch.no_grad():
v, f, _ = mc_from_psr(psr_grid,
zero_level=self.cfg['data']['zero_level'], real_scale=True)
v = v * 2 - 1
if scale is not None:
v *= scale
if center is not None:
v += center
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(v)
mesh.triangles = o3d.utility.Vector3iVector(f)
outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch))
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
if self.cfg['train']['vis_psr']:
dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all'
visualize_psr_grid(psr_grid, out_dir=dir_psr_vis)

207
src/training.py Normal file
View file

@ -0,0 +1,207 @@
import os
import numpy as np
import torch
from torch.nn import functional as F
from collections import defaultdict
import trimesh
from tqdm import tqdm
from src.dpsr import DPSR
from src.utils import grid_interp, export_pointcloud, export_mesh, \
mc_from_psr, scale2onet, GaussianSmoothing
from pytorch3d.ops.knn import knn_gather, knn_points
from pytorch3d.loss import chamfer_distance
from pdb import set_trace as st
class Trainer(object):
'''
Args:
cfg (dict): config dictionary
optimizer (optimizer): pytorch optimizer object
device (device): pytorch device
'''
def __init__(self, cfg, optimizer, device=None):
self.optimizer = optimizer
self.device = device
self.cfg = cfg
if self.cfg['train']['w_raw'] != 0:
from src.model import PSR2Mesh
self.psr2mesh = PSR2Mesh.apply
# initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
cfg['model']['grid_res'],
cfg['model']['grid_res']),
sig=cfg['model']['psr_sigma'])
if torch.cuda.device_count() > 1:
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR
self.dpsr = self.dpsr.to(device)
if cfg['train']['gauss_weight']>0.:
self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device)
def train_step(self, inputs, data, model):
''' Performs a training step.
Args:
inputs (torch.tensor): input point clouds
data (dict): data dictionary
model (nn.Module): network to be trained
'''
self.optimizer.zero_grad()
p = data.get('inputs').to(self.device)
out = model(p)
points, normals = out
loss = 0
loss_each = {}
if self.cfg['train']['w_psr'] != 0:
psr_gt = data.get('gt_psr').to(self.device)
if self.cfg['model']['psr_tanh']:
psr_gt = torch.tanh(psr_gt)
psr_grid = self.dpsr(points, normals)
if self.cfg['model']['psr_tanh']:
psr_grid = torch.tanh(psr_grid)
# apply a rescaling weight based on GT SDF values
if self.cfg['train']['gauss_weight']>0:
gauss_sigma = self.cfg['train']['gauss_weight']
# set up the weighting for loss, higher weights
# for points near to the surface
psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1)
delta_x = delta_y = delta_z = 1
grad_x = (psr_gt_pad[:, 2:, :, :] - psr_gt_pad[:, :-2, :, :]) / 2 / delta_x
grad_y = (psr_gt_pad[:, :, 2:, :] - psr_gt_pad[:, :, :-2, :]) / 2 / delta_y
grad_z = (psr_gt_pad[:, :, :, 2:] - psr_gt_pad[:, :, :, :-2]) / 2 / delta_z
grad_x = grad_x[:, :, 1:-1, 1:-1]
grad_y = grad_y[:, 1:-1, :, 1:-1]
grad_z = grad_z[:, 1:-1, 1:-1, :]
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1)
psr_grad_norm = psr_grad.norm(dim=-1)[:, None]
w = torch.nn.ReplicationPad3d(3)(psr_grad_norm)
w = 2*self.gauss_smooth(w).squeeze(1)
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt)
else:
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt)
loss += loss_each['psr']
# regularization on the input point positions via chamfer distance
if self.cfg['train']['w_reg_point'] != 0.:
points_gt = data.get('gt_points').to(self.device)
loss_reg, loss_norm = chamfer_distance(points, points_gt)
loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg
loss += loss_each['reg']
if self.cfg['train']['w_normals'] != 0.:
points_gt = data.get('gt_points').to(self.device)
normals_gt = data.get('gt_points.normals').to(self.device)
x_nn = knn_points(points, points_gt, K=1)
x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :]
cham_norm_x = F.l1_loss(normals, x_normals_near)
loss_norm = cham_norm_x
loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm
loss += loss_each['normals']
if self.cfg['train']['w_raw'] != 0:
res = self.cfg['model']['grid_res']
# DPSR to get grid
psr_grid = self.dpsr(points, normals)
if self.cfg['model']['psr_tanh']:
psr_grid = torch.tanh(psr_grid)
v, f, n = self.psr2mesh(psr_grid)
pts_gt = data.get('gt_points').to(self.device)
loss, _ = chamfer_distance(v, pts_gt)
loss.backward()
self.optimizer.step()
return loss.item(), loss_each
def save(self, model, data, epoch, id):
p = data.get('inputs').to(self.device)
exp_pcl = self.cfg['train']['exp_pcl']
exp_mesh = self.cfg['train']['exp_mesh']
exp_gt = self.cfg['generation']['exp_gt']
exp_input = self.cfg['generation']['exp_input']
model.eval()
with torch.no_grad():
points, normals = model(p)
if exp_gt:
points_gt = data.get('gt_points').to(self.device)
normals_gt = data.get('gt_points.normals').to(self.device)
if exp_pcl:
dir_pcl = self.cfg['train']['dir_pcl']
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals)
if exp_gt:
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt)
if exp_input:
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p))
if exp_mesh:
dir_mesh = self.cfg['train']['dir_mesh']
psr_grid = self.dpsr(points, normals)
# psr_grid = torch.tanh(psr_grid)
with torch.no_grad():
v, f, _ = mc_from_psr(psr_grid,
zero_level=self.cfg['data']['zero_level'])
outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id))
export_mesh(outdir_mesh, scale2onet(v), f)
if exp_gt:
psr_gt = self.dpsr(points_gt, normals_gt)
with torch.no_grad():
v, f, _ = mc_from_psr(psr_gt,
zero_level=self.cfg['data']['zero_level'])
export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f)
def evaluate(self, val_loader, model):
''' Performs an evaluation.
Args:
val_loader (dataloader): pytorch dataloader
'''
eval_list = defaultdict(list)
for data in tqdm(val_loader):
eval_step_dict = self.eval_step(data, model)
for k, v in eval_step_dict.items():
eval_list[k].append(v)
eval_dict = {k: np.mean(v) for k, v in eval_list.items()}
return eval_dict
def eval_step(self, data, model):
''' Performs an evaluation step.
Args:
data (dict): data dictionary
'''
model.eval()
eval_dict = {}
p = data.get('inputs').to(self.device)
psr_gt = data.get('gt_psr').to(self.device)
with torch.no_grad():
# forward pass
points, normals = model(p)
# DPSR to get predicted psr grid
psr_grid = self.dpsr(points, normals)
eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item()
eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item()
return eval_dict

645
src/utils.py Normal file
View file

@ -0,0 +1,645 @@
import torch
import io, os, logging, urllib
import yaml
import trimesh
import imageio
import numbers
import math
import numpy as np
from collections import OrderedDict
from plyfile import PlyData
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
from skimage import measure, img_as_float32
from pytorch3d.structures import Meshes
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
from igl import adjacency_matrix, connected_components
import open3d as o3d
##################################################
# Below are functions for DPSR
def fftfreqs(res, dtype=torch.float32, exact=True):
"""
Helper function to return frequency tensors
:param res: n_dims int tuple of number of frequency modes
:return:
"""
n_dims = len(res)
freqs = []
for dim in range(n_dims - 1):
r_ = res[dim]
freq = np.fft.fftfreq(r_, d=1/r_)
freqs.append(torch.tensor(freq, dtype=dtype))
r_ = res[-1]
if exact:
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
else:
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
omega = torch.meshgrid(freqs)
omega = list(omega)
omega = torch.stack(omega, dim=-1)
return omega
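# Shape sketch (res=(4, 4, 4) is an assumed example value, kept as comments):
#   omega = fftfreqs((4, 4, 4))
#   # the last spatial dim uses rfftfreq, so omega.shape == (4, 4, 3, 3)
#   # frequencies along a full axis: [0, 1, -2, -1]; along the half (real-FFT) axis: [0, 1, 2]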
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
"""
multiply tensor x by i ** deg
"""
deg %= 4
if deg == 0:
res = x
elif deg == 1:
res = x[..., [1, 0]]
res[..., 0] = -res[..., 0]
elif deg == 2:
res = -x
elif deg == 3:
res = x[..., [1, 0]]
res[..., 1] = -res[..., 1]
return res
def spec_gaussian_filter(res, sig):
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
filter_.requires_grad = False
return filter_
def grid_interp(grid, pts, batched=True):
"""
:param grid: tensor of shape (batch, *size, in_features)
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
:return values at query points
"""
if not batched:
grid = grid.unsqueeze(0)
pts = pts.unsqueeze(0)
dim = pts.shape[-1]
bs = grid.shape[0]
size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
cubesize = 1.0 / size
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
tmp = torch.tensor([0,1],dtype=torch.long)
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
# latent code on neighbor nodes
if dim == 2:
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
else:
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features)
# weights of neighboring nodes
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
pos_ = pos_.type(pts.dtype)
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features)
if not batched:
query_values = query_values.squeeze(0)
return query_values
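# Shape sketch for the trilinear interpolation above (assumed sizes, kept as comments):
#   grid = torch.rand(2, 32, 32, 32, 16)   # (batch, res, res, res, in_features)
#   pts = torch.rand(2, 1000, 3)           # query points in [0, 1)
#   vals = grid_interp(grid, pts)          # vals.shape == (2, 1000, 16)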
def scatter_to_grid(inds, vals, size):
"""
Scatter update values into empty tensor of size size.
:param inds: (#values, dims)
:param vals: (#values)
:param size: tuple for size. len(size)=dims
"""
dims = inds.shape[1]
assert(inds.shape[0] == vals.shape[0])
assert(len(size) == dims)
dev = vals.device
# result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
# # flatten inds
result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
# flatten inds
fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
fac = torch.tensor(fac, device=dev).type(inds.dtype)
inds_fold = torch.sum(inds*fac, dim=-1) # [#values,]
result.scatter_add_(0, inds_fold, vals)
result = result.view(*size)
return result
def point_rasterize(pts, vals, size):
"""
:param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
:param vals: point values, tensor of shape (batch, num_points, features)
:param size: len(size)=dim tuple for grid size
:return rasterized values (batch, features, res0, res1, res2)
"""
dim = pts.shape[-1]
assert(pts.shape[:2] == vals.shape[:2])
assert(pts.shape[2] == dim)
size_list = list(size)
size = torch.tensor(size).to(pts.device).float()
cubesize = 1.0 / size
bs = pts.shape[0]
nf = vals.shape[-1]
npts = pts.shape[1]
dev = pts.device
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
tmp = torch.tensor([0,1],dtype=torch.long)
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
# ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
# weights of neighboring nodes
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
pos_ = pos_.type(pts.dtype)
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1)
ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim)
ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
# ind_f = torch.arange(nf).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
inds = torch.cat([ind_b, ind_f, ind_n], dim=-1) # (batch, num_points, 2**dim, nf, 1+1+dim)
# weighted values
vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
tensor_size = [bs, nf] + size_list
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list)
return raster # [batch, nf, res, res, res]
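# Roughly the adjoint of grid_interp: point features are scattered onto the grid with
# trilinear weights. Shape sketch (assumed sizes, kept as comments):
#   pts = torch.rand(2, 1000, 3)                       # (batch, num_points, 3), in [0, 1)
#   vals = torch.rand(2, 1000, 4)                      # per-point features
#   raster = point_rasterize(pts, vals, (32, 32, 32))  # raster.shape == (2, 4, 32, 32, 32)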
##################################################
# Below are the utilization functions in general
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.n = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.n = n
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
@property
def valcavg(self):
return self.val.sum().item() / (self.n != 0).sum().item()
@property
def avgcavg(self):
return self.avg.sum().item() / (self.count != 0).sum().item()
def load_model_manual(state_dict, model):
new_state_dict = OrderedDict()
is_model_parallel = isinstance(model, torch.nn.DataParallel)
for k, v in state_dict.items():
if k.startswith('module.') != is_model_parallel:
if k.startswith('module.'):
# remove module
k = k[7:]
else:
# add module
k = 'module.' + k
new_state_dict[k]=v
model.load_state_dict(new_state_dict)
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
'''
Run marching cubes from PSR grid
'''
batch_size = psr_grid.shape[0]
s = psr_grid.shape[-1] # size of psr_grid
psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
if batch_size>1:
verts, faces, normals = [], [], []
for i in range(batch_size):
verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
verts.append(verts_cur)
faces.append(faces_cur)
normals.append(normals_cur)
verts = np.stack(verts, axis = 0)
faces = np.stack(faces, axis = 0)
normals = np.stack(normals, axis = 0)
else:
try:
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
except:
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
if real_scale:
verts = verts / (s-1) # scale to range [0, 1]
else:
verts = verts / s # scale to range [0, 1)
if pytorchify:
device = psr_grid.device
verts = torch.Tensor(np.ascontiguousarray(verts)).to(device)
faces = torch.Tensor(np.ascontiguousarray(faces)).to(device)
normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)
return verts, faces, normals
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
verts = verts.squeeze()
faces = faces.squeeze()
pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size)
if mask_gt is not None:
#! only evaluate within the intersection
mask = mask & mask_gt
# find 3D points intersected on the mesh
if True:
w_masked = w[mask]
f_p = faces[pix_to_face[mask]].long() # corresponding faces for each pixel
# corresponding vertices for p_closest
v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
# calculate the intersection point of each pixel and the mesh
p_inters = w_masked[..., 0, None] * v_a + \
w_masked[..., 1, None] * v_b + \
w_masked[..., 2, None] * v_c
else:
# backproject ndc to world coordinates using z-buffer
W, H = img_size[1], img_size[0]
xy = uv.to(mask.device)[mask]
x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
z = zbuf.squeeze().reshape(H * W)[mask]
xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
# if there are outlier points, we should remove them
if (p_inters.max()>1) | (p_inters.min()<-1):
mask_bound = (p_inters>=-1) & (p_inters<=1)
mask_bound = (mask_bound.sum(dim=-1)==3)
mask[mask==True] = mask_bound
p_inters = p_inters[mask_bound]
print('Found outlier intersection points; removing them.')
return p_inters, mask, f_p, w_masked
def mesh_rasterization(verts, faces, pose, img_size):
'''
Use PyTorch3D to rasterize the mesh given a camera
'''
transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
if isinstance(pose, PerspectiveCameras):
transformed_v[..., 2] = 1/transformed_v[..., 2]
# find p_closest on mesh of each pixel via rasterization
transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
transformed_mesh,
image_size=img_size,
blur_radius=0,
faces_per_pixel=1,
perspective_correct=False
)
pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
mask = pix_to_face.clone() != -1
mask = mask.squeeze()
pix_to_face = pix_to_face.squeeze()
w = bary_coords.reshape(-1, 3)
return pix_to_face, w, mask
def verts_on_largest_mesh(verts, faces):
'''
verts: Numpy array or Torch.Tensor (N, 3)
faces: Numpy array (N, 3)
'''
if torch.is_tensor(faces):
verts = verts.squeeze().detach().cpu().numpy()
faces = faces.squeeze().int().detach().cpu().numpy()
A = adjacency_matrix(faces)
num, conn_idx, conn_size = connected_components(A)
if num == 0:
v_large, f_large = verts, faces
else:
max_idx = conn_size.argmax() # find the index of the largest component
v_large = verts[conn_idx==max_idx] # keep points on the largest component
if True:
mesh_largest = trimesh.Trimesh(verts, faces)
connected_comp = mesh_largest.split(only_watertight=False)
mesh_largest = connected_comp[max_idx]
v_large, f_large = mesh_largest.vertices, mesh_largest.faces
v_large = v_large.astype(np.float32)
return v_large, f_large
def load_pointcloud(in_file):
plydata = PlyData.read(in_file)
vertices = np.stack([
plydata['vertex']['x'],
plydata['vertex']['y'],
plydata['vertex']['z']
], axis=1)
return vertices
# General config
def load_config(path, default_path=None):
''' Loads config file.
Args:
path (str): path to config file
default_path (str): path to the default config file
'''
# Load configuration from file itself
with open(path, 'r') as f:
cfg_special = yaml.load(f, Loader=yaml.Loader)
# Check if we should inherit from a config
inherit_from = cfg_special.get('inherit_from')
# If yes, load this config first as default
# If no, use the default_path
if inherit_from is not None:
cfg = load_config(inherit_from, default_path)
elif default_path is not None:
with open(default_path, 'r') as f:
cfg = yaml.load(f, Loader=yaml.Loader)
else:
cfg = dict()
# Include main configuration
update_recursive(cfg, cfg_special)
return cfg
def update_config(config, unknown):
# update config given args
for idx,arg in enumerate(unknown):
if arg.startswith("--"):
keys = arg.replace("--","").split(':')
assert(len(keys)==2)
k1, k2 = keys
argtype = type(config[k1][k2])
if argtype == bool:
v = unknown[idx+1].lower() == 'true'
else:
if config[k1][k2] is not None:
v = type(config[k1][k2])(unknown[idx+1])
else:
v = unknown[idx+1]
print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
config[k1][k2] = v
return config
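# Override sketch: unknown command-line args of the form --section:key value are parsed
# into the nested config. The entry-script name and keys below are assumed examples
# (kept as comments):
#   python train.py config.yaml --train:lr 1e-4 --model:psr_tanh False
#   # equivalent to update_config(config, ['--train:lr', '1e-4', '--model:psr_tanh', 'False'])
#   # each value is cast with the type of the existing entry; booleans are parsed case-insensitively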
def initialize_logger(cfg):
out_dir = cfg['train']['out_dir']
if not os.path.exists(out_dir):
os.makedirs(out_dir)
cfg['train']['dir_model'] = os.path.join(out_dir, 'model')
os.makedirs(cfg['train']['dir_model'], exist_ok=True)
if cfg['train']['exp_mesh']:
cfg['train']['dir_mesh'] = os.path.join(out_dir, 'vis/mesh')
os.makedirs(cfg['train']['dir_mesh'], exist_ok=True)
if cfg['train']['exp_pcl']:
cfg['train']['dir_pcl'] = os.path.join(out_dir, 'vis/pointcloud')
os.makedirs(cfg['train']['dir_pcl'], exist_ok=True)
if cfg['train']['vis_rendering']:
cfg['train']['dir_rendering'] = os.path.join(out_dir, 'vis/rendering')
os.makedirs(cfg['train']['dir_rendering'], exist_ok=True)
if cfg['train']['o3d_show']:
cfg['train']['dir_o3d'] = os.path.join(out_dir, 'vis/o3d')
os.makedirs(cfg['train']['dir_o3d'], exist_ok=True)
logger = logging.getLogger("train")
logger.setLevel(logging.DEBUG)
logger.handlers = []
# ch = logging.StreamHandler()
# logger.addHandler(ch)
fh = logging.FileHandler(os.path.join(cfg['train']['out_dir'], "log.txt"))
logger.addHandler(fh)
logger.info('Output dir: %s', out_dir)
return logger
def update_recursive(dict1, dict2):
''' Update two config dictionaries recursively.
Args:
dict1 (dict): first dictionary to be updated
dict2 (dict): second dictionary which entries should be used
'''
for k, v in dict2.items():
if k not in dict1:
dict1[k] = dict()
if isinstance(v, dict):
update_recursive(dict1[k], v)
else:
dict1[k] = v
def export_pointcloud(name, points, normals=None):
if len(points.shape) > 2:
points = points[0]
if normals is not None:
normals = normals[0]
if isinstance(points, torch.Tensor):
points = points.detach().cpu().numpy()
if normals is not None:
normals = normals.detach().cpu().numpy()
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
if normals is not None:
pcd.normals = o3d.utility.Vector3dVector(normals)
o3d.io.write_point_cloud(name, pcd)
def export_mesh(name, v, f):
if len(v.shape) > 2:
v, f = v[0], f[0]
if isinstance(v, torch.Tensor):
v = v.detach().cpu().numpy()
f = f.detach().cpu().numpy()
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(v)
mesh.triangles = o3d.utility.Vector3iVector(f)
o3d.io.write_triangle_mesh(name, mesh)
def scale2onet(p, scale=1.2):
'''
Scale the point cloud from SAP to ONet range
'''
return (p - 0.5) * scale
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
if model is not None:
if schedule is not None:
optimizer = torch.optim.Adam([
{"params": model.parameters(),
"lr": schedule[0].get_learning_rate(epoch)},
{"params": inputs,
"lr": schedule[1].get_learning_rate(epoch)}])
elif 'lr' in cfg['train']:
optimizer = torch.optim.Adam([
{"params": model.parameters(),
"lr": float(cfg['train']['lr'])},
{"params": inputs,
"lr": float(cfg['train']['lr_pcl'])}])
else:
raise Exception('no known learning rate')
else:
if schedule is not None:
optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
else:
optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
return optimizer
def is_url(url):
scheme = urllib.parse.urlparse(url).scheme
return scheme in ('http', 'https')
def load_url(url):
'''Load a model state dict from a URL.
Args:
url (str): url to the saved checkpoint
'''
print(url)
print('=> Loading checkpoint from url...')
state_dict = model_zoo.load_url(url, progress=True)
return state_dict
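# Illustrative usage combining the two helpers above (the path/URL is a placeholder, not a real checkpoint):
#   if is_url(path):
#       state_dict = load_url(path)
#   else:
#       state_dict = torch.load(path)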
class GaussianSmoothing(nn.Module):
"""
Apply Gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed
separately for each channel in the input using a depthwise convolution.
Arguments:
channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
kernel_size (int, sequence): Size of the gaussian kernel.
sigma (float, sequence): Standard deviation of the gaussian kernel.
dim (int, optional): The number of dimensions of the data.
Default value is 3 (volumetric).
"""
def __init__(self, channels, kernel_size, sigma, dim=3):
super(GaussianSmoothing, self).__init__()
if isinstance(kernel_size, numbers.Number):
kernel_size = [kernel_size] * dim
if isinstance(sigma, numbers.Number):
sigma = [sigma] * dim
# The gaussian kernel is the product of the
# gaussian function of each dimension.
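# i.e. for dim=3 the kernel is K(x, y, z) = g(x) * g(y) * g(z), where
# g(t) = exp(-((t - mu) / sigma)**2 / 2) / (sigma * sqrt(2 * pi)) and mu = (size - 1) / 2.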
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
torch.exp(-((mgrid - mean) / std) ** 2 / 2)
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
if dim == 1:
self.conv = F.conv1d
elif dim == 2:
self.conv = F.conv2d
elif dim == 3:
self.conv = F.conv3d
else:
raise RuntimeError(
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
)
def forward(self, input):
"""
Apply gaussian filter to input.
Arguments:
input (torch.Tensor): Input to apply gaussian filter on.
Returns:
filtered (torch.Tensor): Filtered output.
"""
return self.conv(input, weight=self.weight, groups=self.groups)
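# Minimal usage sketch (shapes and values are illustrative): smooth a single-channel 3D grid.
# Note that the convolution above is applied without padding, so the output shrinks by
# kernel_size - 1 per spatial dimension unless the input is padded beforehand, e.g. with F.pad:
#   smoothing = GaussianSmoothing(channels=1, kernel_size=5, sigma=1.0, dim=3)
#   grid = torch.rand(1, 1, 32, 32, 32)
#   grid = F.pad(grid, (2, 2, 2, 2, 2, 2), mode='replicate')
#   smoothed = smoothing(grid)  # -> (1, 1, 32, 32, 32)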
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
def get_learning_rate_schedules(schedule_specs):
schedules = []
for key in schedule_specs.keys():
schedules.append(StepLearningRateSchedule(
schedule_specs[key]['initial'],
schedule_specs[key]["interval"],
schedule_specs[key]["factor"],
schedule_specs[key]["final"]))
return schedules
class LearningRateSchedule:
def get_learning_rate(self, epoch):
pass
class StepLearningRateSchedule(LearningRateSchedule):
def __init__(self, initial, interval, factor, final=1e-6):
self.initial = float(initial)
self.interval = interval
self.factor = factor
self.final = float(final)
def get_learning_rate(self, epoch):
lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
if lr > self.final:
return lr
else:
return self.final
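# Worked example of the step decay (values are illustrative): with initial=1e-3, interval=500,
# factor=0.5 and final=1e-5, epochs 0-499 give 1e-3, 500-999 give 5e-4, 1000-1499 give 2.5e-4, ...,
# and the rate is clamped so it never falls below `final` (nor below the hard-coded 5e-6 floor).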
def adjust_learning_rate(lr_schedules, optimizer, epoch):
for i, param_group in enumerate(optimizer.param_groups):
param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

175
src/visualize.py Normal file

@@ -0,0 +1,175 @@
import os
import torch
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
from skimage import measure
from src.utils import calc_inters_points, grid_interp
from scipy import ndimage
from tqdm import trange
from torchvision.utils import save_image
from pdb import set_trace as st
def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None):
''' Visualize the extracted mesh and the optimized point cloud in an Open3D viewer.
Args:
vis (open3d.visualization.Visualizer): Open3D visualizer (screenshots are skipped if None)
points (tensor): point cloud positions
normals (tensor): point normals
verts (array): mesh vertices
faces (array): mesh faces
cfg (dict): config dictionary
it (int): current iteration, used to name the saved screenshots
color_v (array, optional): per-vertex colors
'''
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(verts)
mesh.triangles = o3d.utility.Vector3iVector(faces)
mesh.paint_uniform_color(np.array([0.7,0.7,0.7]))
if color_v is not None:
mesh.vertex_colors = o3d.utility.Vector3dVector(color_v)
if vis is not None:
dir_o3d = cfg['train']['dir_o3d']
wire = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
p = points.squeeze(0).detach().cpu().numpy()
n = normals.squeeze(0).detach().cpu().numpy()
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(p)
pcd.normals = o3d.utility.Vector3dVector(n)
pcd.paint_uniform_color(np.array([0.7,0.7,1.0]))
# pcd = pcd.uniform_down_sample(5)
vis.clear_geometries()
vis.add_geometry(mesh)
vis.update_geometry(mesh)
#! Thingi wheel - an example of how to set the camera in the Open3D viewer
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
vis.get_view_control().set_zoom(0.7)
vis.poll_events()
out_path = os.path.join(dir_o3d, '{}.jpg'.format(it))
vis.capture_screen_image(out_path)
vis.clear_geometries()
vis.add_geometry(pcd, reset_bounding_box=False)
vis.update_geometry(pcd)
vis.get_render_option().point_show_normal=True # visualize point normals
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
vis.get_view_control().set_zoom(0.7)
vis.poll_events()
out_path = os.path.join(dir_o3d, '{}_pcd.jpg'.format(it))
vis.capture_screen_image(out_path)
def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'):
if pose is not None:
device = psr_grid.device
# get world coordinate of grid points [-1, 1]
res = psr_grid.shape[-1]
x = torch.linspace(-1, 1, steps=res)
co_x, co_y, co_z = torch.meshgrid(x, x, x)
co_grid = torch.stack(
[co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)],
dim=1).to(device).unsqueeze(0)
# visualize the projected occ_soft value
res = 128
psr_grid = psr_grid.reshape(-1)
out_mask = psr_grid>0
in_mask = psr_grid<0
pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze()
vis_mask = (pix[..., 0]>=0) & (pix[..., 0]<=res-1) & \
(pix[..., 1]>=0) & (pix[..., 1]<=res-1)
pix_out = pix[vis_mask & out_mask]
pix_in = pix[vis_mask & in_mask]
img = torch.ones([res,res]).to(device)
psr_grid = torch.sigmoid(- psr_grid * 5)
img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask]
img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask]
# save_image(img, 'tmp.png', normalize=True)
return img
elif out_dir is not None:
dir_psr_vis = out_dir
os.makedirs(dir_psr_vis, exist_ok=True)
psr_grid = psr_grid.squeeze().detach().cpu().numpy()
axis = ['x', 'y', 'z']
s = psr_grid.shape[0]
for idx in trange(s):
my_dpi = 100
plt.figure(figsize=(1000/my_dpi, 300/my_dpi), dpi=my_dpi)
plt.subplot(1, 3, 1)
plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode='nearest'), cmap='nipy_spectral')
plt.clim(-1, 1)
plt.colorbar()
plt.title('x')
plt.grid("off")
plt.axis("off")
plt.subplot(1, 3, 2)
plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode='nearest'), cmap='nipy_spectral')
plt.clim(-1, 1)
plt.colorbar()
plt.title('y')
plt.grid("off")
plt.axis("off")
plt.subplot(1, 3, 3)
plt.imshow(ndimage.rotate(psr_grid[:,:,idx], 90, mode='nearest'), cmap='nipy_spectral')
plt.clim(-1, 1)
plt.colorbar()
plt.title('z')
plt.grid("off")
plt.axis("off")
plt.savefig(os.path.join(dir_psr_vis, '{}'.format(idx)), pad_inches = 0, dpi=100)
plt.close()
os.system("rm {}/{}".format(dir_psr_vis, out_video_name))
os.system("ffmpeg -framerate 25 -start_number 0 -i {}/%d.png -pix_fmt yuv420p -crf 17 {}/{}".format(dir_psr_vis, dir_psr_vis, out_video_name))
def visualize_mesh_phong(v, f, n, pose, img_size, name, device='cpu'):
#! Mesh rendering using Phong shading model
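# (sketch of the shading computed below) per-pixel intensity I = ambient + k_d * max(n . l, 0),
# where n is the surface normal interpolated from the triangle vertices with the barycentric weights w,
# l points towards the camera-centre light source, and the specular term is omitted.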
_, mask, f_p, w = calc_inters_points(v, f, pose, img_size)
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
n_inters = w[..., 0, None] * n_a.squeeze() + \
w[..., 1, None] * n_b.squeeze() + \
w[..., 2, None] * n_c.squeeze()
n_inters = n_inters.detach().to(device)
light_source = -pose.R@pose.T.squeeze()
light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float()
diffuse_per = torch.Tensor([0.7,0.7,0.7]).float()
ambiant = torch.Tensor([0.3,0.3,0.3]).float()
diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device)
phong = torch.ones([img_size[0]*img_size[1], 3]).to(device)
phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0)
pp = phong.reshape(img_size[0], img_size[1], -1)
save_image(pp.permute(2, 0, 1), name)
def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None):
p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
# normals for p_inters
n_inters = None
if n is not None:
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
n_inters = w[..., 0, None] * n_a.squeeze() + \
w[..., 1, None] * n_b.squeeze() + \
w[..., 2, None] * n_c.squeeze()
if ray is not None:
ray = ray.squeeze()[mask]
fea = None
if fea_grid is not None:
fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze()
# use MLP to regress color
color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze()
return color_pred, mask

185
train.py Normal file

@@ -0,0 +1,185 @@
import os
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
import torch
import torch.optim as optim
import open3d as o3d
import numpy as np; np.set_printoptions(precision=4)
import shutil, argparse, time
from torch.utils.tensorboard import SummaryWriter
from src import config
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
from src.training import Trainer
from src.model import Encode2Points
from src.utils import load_config, initialize_logger, \
AverageMeter, load_model_manual
def main():
parser = argparse.ArgumentParser(description='Shape As Points training')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
args = parser.parse_args()
cfg = load_config(args.config, 'configs/default.yaml')
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
input_type = cfg['data']['input_type']
batch_size = cfg['train']['batch_size']
model_selection_metric = cfg['train']['model_selection_metric']
# requires PyTorch 1.0 or newer
assert int(torch.__version__.split('.')[0]) >= 1
# boiler-plate
if cfg['train']['timestamp']:
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
logger = initialize_logger(cfg)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml'))
logger.info("using GPU: " + torch.cuda.get_device_name(0))
# TensorboardX writer
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
if not os.path.exists(tblogdir):
os.makedirs(tblogdir, exist_ok=True)
writer = SummaryWriter(log_dir=tblogdir)
inputs = None
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg)
vis_dataset = config.get_dataset('vis', cfg)
collate_fn = collate_remove_none
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'], shuffle=True,
collate_fn=collate_fn,
worker_init_fn=worker_init_fn)
val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
collate_fn=collate_remove_none,
worker_init_fn=worker_init_fn)
vis_loader = torch.utils.data.DataLoader(
vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
collate_fn=collate_fn,
worker_init_fn=worker_init_fn)
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
else:
model = Encode2Points(cfg).to(device)
n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info('Number of parameters: %d'% n_parameter)
# load model
try:
# load model
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
load_model_manual(state_dict['state_dict'], model)
out = "Load model from iteration %d" % state_dict.get('it', 0)
logger.info(out)
except Exception:
# no checkpoint found, start training from scratch
state_dict = dict()
metric_val_best = state_dict.get(
'loss_val_best', np.inf)
logger.info('Current best validation metric (%s): %.8f'
% (model_selection_metric, metric_val_best))
LR = float(cfg['train']['lr'])
optimizer = optim.Adam(model.parameters(), lr=LR)
start_epoch = state_dict.get('epoch', -1)
it = state_dict.get('it', -1)
trainer = Trainer(cfg, optimizer, device=device)
runtime = {}
runtime['all'] = AverageMeter()
# training loop
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
for batch in train_loader:
it += 1
start = time.time()
loss, loss_each = trainer.train_step(inputs, batch, model)
# measure elapsed time
end = time.time()
runtime['all'].update(end - start)
if it % cfg['train']['print_every'] == 0:
log_text = ('[Epoch %02d] it=%d, loss=%.4f') %(epoch, it, loss)
writer.add_scalar('train/loss', loss, it)
if loss_each is not None:
for k, l in loss_each.items():
if l.item() != 0.:
log_text += (' loss_%s=%.4f') % (k, l.item())
writer.add_scalar('train/%s' % k, l, it)
log_text += (' time=%.3f / %.2f') % (runtime['all'].val, runtime['all'].sum)
logger.info(log_text)
if (it>0)& (it % cfg['train']['visualize_every'] == 0):
for i, batch_vis in enumerate(vis_loader):
trainer.save(model, batch_vis, it, i)
if i >= 4:
break
logger.info('Saved mesh and pointcloud')
# run validation
if it > 0 and (it % cfg['train']['validate_every']) == 0:
eval_dict = trainer.evaluate(val_loader, model)
metric_val = eval_dict[model_selection_metric]
logger.info('Validation metric (%s): %.4f'
% (model_selection_metric, metric_val))
for k, v in eval_dict.items():
writer.add_scalar('val/%s' % k, v, it)
if metric_val <= metric_val_best:
metric_val_best = metric_val
logger.info('New best model (loss %.4f)' % metric_val_best)
state = {'epoch': epoch,
'it': it,
'loss_val_best': metric_val_best}
state['state_dict'] = model.state_dict()
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt'))
# save checkpoint
if (epoch > 0) & (it % cfg['train']['checkpoint_every'] == 0):
state = {'epoch': epoch,
'it': it,
'loss_val_best': metric_val_best}
pcl = None
state['state_dict'] = model.state_dict()
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
if (it % cfg['train']['backup_every'] == 0):
torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt'))
logger.info("Backup model at iteration %d" % it)
logger.info("Save new model at iteration %d" % it)
done=time.time()
if __name__ == '__main__':
main()