init commit
This commit is contained in:
parent
c4f63f1510
commit
12757682f1
150
.gitignore
vendored
Normal file
150
.gitignore
vendored
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
/out
|
||||||
|
/data
|
||||||
|
.vscode
|
||||||
|
.cache
|
||||||
|
*.pyc
|
||||||
|
*.pyd
|
||||||
|
*.pt
|
||||||
|
*.so
|
||||||
|
*.o
|
||||||
|
*.prof
|
||||||
|
*.swp
|
||||||
|
*.lib
|
||||||
|
*.obj
|
||||||
|
*.exp
|
||||||
|
.nfs*
|
||||||
|
*.jpg
|
||||||
|
*.png
|
||||||
|
*.ply
|
||||||
|
*.off
|
||||||
|
*.npz
|
||||||
|
*.txt
|
||||||
|
# *.sh
|
||||||
|
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
160
README.md
160
README.md
|
@ -1,5 +1,6 @@
|
||||||
# Shape As Points (SAP)
|
# Shape As Points (SAP)
|
||||||
[**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (13 min)**](https://youtu.be/TgR0NvYty0A) <br>
|
|
||||||
|
### [**Paper**](https://arxiv.org/abs/2106.03452) | [**Project Page**](https://pengsongyou.github.io/sap) | [**Short Video (6 min)**](https://youtu.be/FL8LMk_qWb4) | [**Long Video (12 min)**](https://youtu.be/TgR0NvYty0A) <br>
|
||||||
|
|
||||||
![](./media/teaser_wheel.gif)
|
![](./media/teaser_wheel.gif)
|
||||||
|
|
||||||
|
@ -10,13 +11,164 @@ Shape As Points: A Differentiable Poisson Solver
|
||||||
**NeurIPS 2021 (Oral)**
|
**NeurIPS 2021 (Oral)**
|
||||||
|
|
||||||
|
|
||||||
## Code is coming soon!
|
|
||||||
|
|
||||||
If you find our code or paper useful, please consider citing
|
If you find our code or paper useful, please consider citing
|
||||||
```bibtex
|
```bibtex
|
||||||
@inproceedings{Peng2021SAP,
|
@inproceedings{Peng2021SAP,
|
||||||
author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Andreas, Geiger},
|
author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas},
|
||||||
title = {Shape As Points: A Differentiable Poisson Solver},
|
title = {Shape As Points: A Differentiable Poisson Solver},
|
||||||
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
|
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
|
||||||
year = {2021}}
|
year = {2021}}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
First you have to make sure that you have all dependencies in place.
|
||||||
|
The simplest way to do so, is to use [anaconda](https://www.anaconda.com/).
|
||||||
|
|
||||||
|
You can create an anaconda environment called `sap` using
|
||||||
|
```
|
||||||
|
conda env create -f environment.yaml
|
||||||
|
conda activate sap
|
||||||
|
```
|
||||||
|
|
||||||
|
Now, you can install [PyTorch3D](https://pytorch3d.org/) 0.6.0 from the [official instruction](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux) as follows
|
||||||
|
```sh
|
||||||
|
pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu102_pyt190/download.html
|
||||||
|
```
|
||||||
|
And install [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter):
|
||||||
|
```sh
|
||||||
|
conda install pytorch-scatter -c pyg
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Demo - Quick Start
|
||||||
|
|
||||||
|
First, run the script to get the demo data:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash scripts/download_demo_data.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### Optimization-based 3D Surface Reconstruction
|
||||||
|
|
||||||
|
You can now quickly test our code on the data shown in the teaser. To this end, simply run:
|
||||||
|
|
||||||
|
```python
|
||||||
|
python optim_hierarchy.py configs/optim_based/teaser.yaml
|
||||||
|
```
|
||||||
|
This script should create a folder `out/demo_optim` where the output meshes and the optimized oriented point clouds under different grid resolution are stored.
|
||||||
|
|
||||||
|
To visualize the optimization process on the fly, you can set `o3d_show: Frue` in [`configs/optim_based/teaser.yaml`](https://github.com/autonomousvision/shape_as_points/tree/main/configs/optim_based/teaser.yaml).
|
||||||
|
|
||||||
|
### Learning-based 3D Surface Reconstruction
|
||||||
|
You can also test SAP on another application where we can reconstruct from unoriented point clouds with either **large noises** or **outliers** with a learned network.
|
||||||
|
|
||||||
|
![](./media/results_large_noise.gif)
|
||||||
|
|
||||||
|
For the point clouds with large noise as shown above, you can run:
|
||||||
|
```python
|
||||||
|
python generate.py configs/learning_based/demo_large_noise.yaml
|
||||||
|
```
|
||||||
|
The results can been found at `out/demo_shapenet_large_noise/generation/vis`.
|
||||||
|
|
||||||
|
![](./media/results_outliers.gif)
|
||||||
|
As for the point clouds with outliers, you can run:
|
||||||
|
```python
|
||||||
|
python generate.py configs/learning_based/demo_outlier.yaml
|
||||||
|
```
|
||||||
|
You can find the reconstrution on `out/demo_shapenet_outlier/generation/vis`.
|
||||||
|
|
||||||
|
|
||||||
|
## Dataset
|
||||||
|
|
||||||
|
We have different dataset for our optimization-based and learning-based settings.
|
||||||
|
|
||||||
|
### Dataset for Optimization-based Reconstruction
|
||||||
|
Here we consider the following dataset:
|
||||||
|
- [Thingi10K](https://arxiv.org/abs/1605.04797) (synthetic)
|
||||||
|
- [Surface Reconstruction Benchmark (SRB)](https://github.com/fwilliams/deep-geometric-prior) (real scans)
|
||||||
|
- [MPI Dynamic FAUST](https://dfaust.is.tue.mpg.de/) (real scans)
|
||||||
|
|
||||||
|
Please cite the corresponding papers if you use the data.
|
||||||
|
|
||||||
|
You can download the processed dataset (~200 MB) by running:
|
||||||
|
```bash
|
||||||
|
bash scripts/download_optim_data.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### Dataset for Learning-based Reconstruction
|
||||||
|
We train and evaluate on [ShapeNet](https://shapenet.org/).
|
||||||
|
You can download the processed dataset (~220 GB) by running:
|
||||||
|
```bash
|
||||||
|
bash scripts/download_shapenet.sh
|
||||||
|
```
|
||||||
|
After, you should have the dataset in `data/shapenet_psr` folder.
|
||||||
|
|
||||||
|
Alternatively, you can also preprocess the dataset yourself. To this end, you can:
|
||||||
|
* first download the preprocessed dataset (73.4 GB) by running [the script](https://github.com/autonomousvision/occupancy_networks#preprocessed-data) from Occupancy Networks.
|
||||||
|
* check [`scripts/process_shapenet.py`](https://github.com/autonomousvision/shape_as_points/tree/main/scripts/process_shapenet.py), modify the base path and run the code
|
||||||
|
|
||||||
|
|
||||||
|
## Usage for Optimization-based 3D Reconstruction
|
||||||
|
|
||||||
|
For our optimization-based setting, you can consider running with a coarse-to-fine strategy:
|
||||||
|
```python
|
||||||
|
python optim_hierarchy.py configs/optim_based/CONFIG.yaml
|
||||||
|
```
|
||||||
|
We start from a grid resolution of 32^3, and increase to 64^3, 128^3 and finally 256^3.
|
||||||
|
|
||||||
|
Alternatively, you can also run on a single resolution with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
python optim.py configs/optim_based/CONFIG.yaml
|
||||||
|
```
|
||||||
|
You might need to modify the `CONFIG.yaml` accordingly.
|
||||||
|
|
||||||
|
## Usage for Learning-based 3D Reconstruction
|
||||||
|
|
||||||
|
### Mesh Generation
|
||||||
|
To generate meshes using a trained model, use
|
||||||
|
```python
|
||||||
|
python generate.py configs/learning_based/CONFIG.yaml
|
||||||
|
```
|
||||||
|
where you replace `CONFIG.yaml` with the correct config file.
|
||||||
|
|
||||||
|
#### Use a pre-trained model
|
||||||
|
The easiest way is to use a pre-trained model. You can do this by using one of the config files with postfix `_pretrained`.
|
||||||
|
|
||||||
|
For example, for 3D reconstruction from point clouds with outliers using our model with 7x offsets, you can simply run:
|
||||||
|
```python
|
||||||
|
python generate.py configs/learning_based/outlier/ours_7x_pretrained.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
The script will automatically download the pretrained model and run the generation. You can find the outputs in the `out/.../generation_pretrained` folders.
|
||||||
|
|
||||||
|
**Note** config files are only for generation, not for training new models: when these configs are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model.
|
||||||
|
|
||||||
|
We provide the following pretrained models:
|
||||||
|
```
|
||||||
|
noise_small/ours.pt
|
||||||
|
noise_large/ours.pt
|
||||||
|
outlier/ours_1x.pt
|
||||||
|
outlier/ours_3x.pt
|
||||||
|
outlier/ours_5x.pt
|
||||||
|
outlier/ours_7x.pt
|
||||||
|
outlier/ours_3plane.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
To evaluate a trained model, we provide the script [`eval_meshes.py`](https://github.com/autonomousvision/shape_as_points/blob/main/eval_meshes.py). You can run it using:
|
||||||
|
```python
|
||||||
|
python eval_meshes.py configs/learning_based/CONFIG.yaml
|
||||||
|
```
|
||||||
|
The script takes the meshes generated in the previous step and evaluates them using a standardized protocol. The output will be written to `.pkl` and `.csv` files in the corresponding generation folder that can be processed using [pandas](https://pandas.pydata.org/).
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
Finally, to train a new network from scratch, simply run:
|
||||||
|
```python
|
||||||
|
python train.py configs/learning_based/CONFIG.yaml
|
||||||
|
```
|
||||||
|
For available training options, please take a look at `configs/default.yaml`.
|
||||||
|
|
||||||
|
|
104
configs/default.yaml
Normal file
104
configs/default.yaml
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
data:
|
||||||
|
dataset: Shapes3D
|
||||||
|
path: data/ShapeNet
|
||||||
|
class: null
|
||||||
|
data_type: img
|
||||||
|
input_type: pointcloud
|
||||||
|
dim: 3
|
||||||
|
num_points: 1000
|
||||||
|
num_gt_points: 1000
|
||||||
|
num_offset: 1
|
||||||
|
img_size: null
|
||||||
|
n_views_input: 20
|
||||||
|
n_views_per_iter: 2
|
||||||
|
pointcloud_noise: 0
|
||||||
|
pointcloud_file: pointcloud.npz
|
||||||
|
pointcloud_outlier_ratio: 0
|
||||||
|
fixed_scale: 0
|
||||||
|
train_split: train
|
||||||
|
val_split: val
|
||||||
|
test_split: test
|
||||||
|
points_file: null
|
||||||
|
points_iou_file: points.npz
|
||||||
|
points_unpackbits: true
|
||||||
|
padding: 0.1
|
||||||
|
multi_files: null
|
||||||
|
gt_mesh: null
|
||||||
|
zero_level: 0
|
||||||
|
only_single: False
|
||||||
|
sample_only_floor: False
|
||||||
|
model:
|
||||||
|
apply_sigmoid: True
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 0
|
||||||
|
psr_tanh: False
|
||||||
|
normal_normalize: False
|
||||||
|
raster: {}
|
||||||
|
renderer: {}
|
||||||
|
encoder: null
|
||||||
|
predict_normal: True
|
||||||
|
predict_offset: True
|
||||||
|
s_offset: 0.001
|
||||||
|
local_coord: True
|
||||||
|
encoder_kwargs: {}
|
||||||
|
unet3d: False
|
||||||
|
unet3d_kwargs: {}
|
||||||
|
multi_gpu: false
|
||||||
|
rotate_matrix: false
|
||||||
|
c_dim: 512
|
||||||
|
sphere_radius: 0.2
|
||||||
|
train:
|
||||||
|
lr: 1e-3
|
||||||
|
lr_pcl: 2e-2
|
||||||
|
input_mesh: ''
|
||||||
|
out_dir: out/default
|
||||||
|
subsample_vertex: False
|
||||||
|
batch_size: 1
|
||||||
|
n_grow_points: 0
|
||||||
|
resample_every: 0
|
||||||
|
l_weight: {}
|
||||||
|
w_reg_point: 0
|
||||||
|
w_psr: 0
|
||||||
|
w_raw: 0 # train with raw point cloud
|
||||||
|
gauss_weight: 0
|
||||||
|
n_sup_point: 0
|
||||||
|
w_normals: 0
|
||||||
|
total_epochs: 1000
|
||||||
|
print_every: 20
|
||||||
|
visualize_every: 1000
|
||||||
|
save_every: 1000
|
||||||
|
vis_vert_color: True
|
||||||
|
o3d_show: False
|
||||||
|
o3d_vis_pcl: True
|
||||||
|
o3d_window_size: 540
|
||||||
|
vis_rendering: False
|
||||||
|
vis_psr: False
|
||||||
|
save_video: False
|
||||||
|
exp_mesh: True
|
||||||
|
exp_pcl: True
|
||||||
|
checkpoint_every: 1000
|
||||||
|
validate_every: 2000000
|
||||||
|
backup_every: 100000
|
||||||
|
timestamp: False # add timestamp to out_dir name
|
||||||
|
model_selection_metric: loss
|
||||||
|
model_selection_mode: minimize
|
||||||
|
n_workers: 0
|
||||||
|
n_workers_val: 0
|
||||||
|
test:
|
||||||
|
threshold: 0.5
|
||||||
|
eval_mesh: true
|
||||||
|
eval_pointcloud: false
|
||||||
|
model_file: model_best.pt
|
||||||
|
generation:
|
||||||
|
batch_size: 100000
|
||||||
|
exp_gt: False
|
||||||
|
exp_oracle: false
|
||||||
|
exp_input: False
|
||||||
|
vis_n_outputs: 10
|
||||||
|
generate_mesh: true
|
||||||
|
generate_pointcloud: true
|
||||||
|
generation_dir: generation
|
||||||
|
copy_input: true
|
||||||
|
use_sampling: false
|
||||||
|
psr_resolution: 0
|
||||||
|
psr_sigma: 0
|
9
configs/learning_based/demo_large_noise.yaml
Normal file
9
configs/learning_based/demo_large_noise.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
|
||||||
|
inherit_from: configs/learning_based/noise_large/ours.yaml
|
||||||
|
data:
|
||||||
|
class: ['']
|
||||||
|
path: data/demo/shapenet_chair
|
||||||
|
train:
|
||||||
|
out_dir: out/demo_shapenet_large_noise
|
||||||
|
test:
|
||||||
|
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt
|
9
configs/learning_based/demo_outlier.yaml
Normal file
9
configs/learning_based/demo_outlier.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
|
||||||
|
inherit_from: configs/learning_based/outlier/ours_7x.yaml
|
||||||
|
data:
|
||||||
|
class: ['']
|
||||||
|
path: data/demo/shapenet_lamp
|
||||||
|
train:
|
||||||
|
out_dir: out/demo_shapenet_outlier
|
||||||
|
test:
|
||||||
|
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt
|
54
configs/learning_based/noise_large/ours.yaml
Normal file
54
configs/learning_based/noise_large/ours.yaml
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
data:
|
||||||
|
class: null
|
||||||
|
data_type: psr_full
|
||||||
|
input_type: pointcloud
|
||||||
|
path: data/shapenet_psr
|
||||||
|
num_gt_points: 10000
|
||||||
|
num_offset: 7
|
||||||
|
pointcloud_n: 3000
|
||||||
|
pointcloud_noise: 0.025
|
||||||
|
model:
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
psr_tanh: True
|
||||||
|
normal_normalize: False
|
||||||
|
predict_normal: True
|
||||||
|
predict_offset: True
|
||||||
|
c_dim: 32
|
||||||
|
s_offset: 0.001
|
||||||
|
encoder: local_pool_pointnet
|
||||||
|
encoder_kwargs:
|
||||||
|
hidden_dim: 32
|
||||||
|
plane_type: 'grid'
|
||||||
|
grid_resolution: 32
|
||||||
|
unet3d: True
|
||||||
|
unet3d_kwargs:
|
||||||
|
num_levels: 3
|
||||||
|
f_maps: 32
|
||||||
|
in_channels: 32
|
||||||
|
out_channels: 32
|
||||||
|
decoder: simple_local
|
||||||
|
decoder_kwargs:
|
||||||
|
sample_mode: bilinear # bilinear / nearest
|
||||||
|
hidden_size: 32
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
lr: 5e-4
|
||||||
|
out_dir: out/shapenet/noise_025_ours
|
||||||
|
w_psr: 1
|
||||||
|
model_selection_metric: psr_l2
|
||||||
|
print_every: 100
|
||||||
|
checkpoint_every: 200
|
||||||
|
validate_every: 5000
|
||||||
|
backup_every: 10000
|
||||||
|
total_epochs: 400000
|
||||||
|
visualize_every: 5000
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
n_workers: 8
|
||||||
|
n_workers_val: 0
|
||||||
|
generation:
|
||||||
|
exp_gt: False
|
||||||
|
exp_input: True
|
||||||
|
psr_resolution: 128
|
||||||
|
psr_sigma: 2
|
5
configs/learning_based/noise_large/ours_pretrained.yaml
Normal file
5
configs/learning_based/noise_large/ours_pretrained.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
inherit_from: configs/learning_based/noise_large/ours.yaml
|
||||||
|
generation:
|
||||||
|
generation_dir: generation_pretrained
|
||||||
|
test:
|
||||||
|
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_025.pt
|
54
configs/learning_based/noise_small/ours.yaml
Normal file
54
configs/learning_based/noise_small/ours.yaml
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
data:
|
||||||
|
class: null
|
||||||
|
data_type: psr_full
|
||||||
|
input_type: pointcloud
|
||||||
|
path: data/shapenet_psr
|
||||||
|
num_gt_points: 10000
|
||||||
|
num_offset: 7
|
||||||
|
pointcloud_n: 3000
|
||||||
|
pointcloud_noise: 0.005
|
||||||
|
model:
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
psr_tanh: True
|
||||||
|
normal_normalize: False
|
||||||
|
predict_normal: True
|
||||||
|
predict_offset: True
|
||||||
|
c_dim: 32
|
||||||
|
s_offset: 0.001
|
||||||
|
encoder: local_pool_pointnet
|
||||||
|
encoder_kwargs:
|
||||||
|
hidden_dim: 32
|
||||||
|
plane_type: 'grid'
|
||||||
|
grid_resolution: 32
|
||||||
|
unet3d: True
|
||||||
|
unet3d_kwargs:
|
||||||
|
num_levels: 3
|
||||||
|
f_maps: 32
|
||||||
|
in_channels: 32
|
||||||
|
out_channels: 32
|
||||||
|
decoder: simple_local
|
||||||
|
decoder_kwargs:
|
||||||
|
sample_mode: bilinear # bilinear / nearest
|
||||||
|
hidden_size: 32
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
lr: 5e-4
|
||||||
|
out_dir: out/shapenet/noise_005_ours
|
||||||
|
w_psr: 1
|
||||||
|
model_selection_metric: psr_l2
|
||||||
|
print_every: 100
|
||||||
|
checkpoint_every: 200
|
||||||
|
validate_every: 5000
|
||||||
|
backup_every: 10000
|
||||||
|
total_epochs: 400000
|
||||||
|
visualize_every: 5000
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
n_workers: 8
|
||||||
|
n_workers_val: 0
|
||||||
|
generation:
|
||||||
|
exp_gt: False
|
||||||
|
exp_input: True
|
||||||
|
psr_resolution: 128
|
||||||
|
psr_sigma: 2
|
5
configs/learning_based/noise_small/ours_pretrained.yaml
Normal file
5
configs/learning_based/noise_small/ours_pretrained.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
inherit_from: configs/learning_based/noise_small/ours.yaml
|
||||||
|
generation:
|
||||||
|
generation_dir: generation_pretrained
|
||||||
|
test:
|
||||||
|
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_noise_005.pt
|
55
configs/learning_based/outlier/ours_1x.yaml
Normal file
55
configs/learning_based/outlier/ours_1x.yaml
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
data:
|
||||||
|
class: null
|
||||||
|
data_type: psr_full
|
||||||
|
input_type: pointcloud
|
||||||
|
path: data/shapenet_psr
|
||||||
|
num_gt_points: 10000
|
||||||
|
num_offset: 1
|
||||||
|
pointcloud_n: 3000
|
||||||
|
pointcloud_noise: 0.005
|
||||||
|
pointcloud_outlier_ratio: 0.5
|
||||||
|
model:
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
psr_tanh: True
|
||||||
|
normal_normalize: False
|
||||||
|
predict_normal: True
|
||||||
|
predict_offset: True
|
||||||
|
c_dim: 32
|
||||||
|
s_offset: 0.001
|
||||||
|
encoder: local_pool_pointnet
|
||||||
|
encoder_kwargs:
|
||||||
|
hidden_dim: 32
|
||||||
|
plane_type: 'grid'
|
||||||
|
grid_resolution: 32
|
||||||
|
unet3d: True
|
||||||
|
unet3d_kwargs:
|
||||||
|
num_levels: 3
|
||||||
|
f_maps: 32
|
||||||
|
in_channels: 32
|
||||||
|
out_channels: 32
|
||||||
|
decoder: simple_local
|
||||||
|
decoder_kwargs:
|
||||||
|
sample_mode: bilinear # bilinear / nearest
|
||||||
|
hidden_size: 32
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
lr: 5e-4
|
||||||
|
out_dir: out/shapenet/outlier_ours_1x
|
||||||
|
w_psr: 1
|
||||||
|
model_selection_metric: psr_l2
|
||||||
|
print_every: 100
|
||||||
|
checkpoint_every: 200
|
||||||
|
validate_every: 5000
|
||||||
|
backup_every: 10000
|
||||||
|
total_epochs: 400000
|
||||||
|
visualize_every: 5000
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
n_workers: 8
|
||||||
|
n_workers_val: 0
|
||||||
|
generation:
|
||||||
|
exp_gt: False
|
||||||
|
exp_input: True
|
||||||
|
psr_resolution: 128
|
||||||
|
psr_sigma: 2
|
5
configs/learning_based/outlier/ours_1x_pretrained.yaml
Normal file
5
configs/learning_based/outlier/ours_1x_pretrained.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
inherit_from: configs/learning_based/outlier/ours_3x/ours.yaml
|
||||||
|
generation:
|
||||||
|
generation_dir: generation_pretrained
|
||||||
|
test:
|
||||||
|
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_3x.pt
|
54
configs/learning_based/outlier/ours_3plane.yaml
Normal file
54
configs/learning_based/outlier/ours_3plane.yaml
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
data:
|
||||||
|
class: null
|
||||||
|
data_type: psr_full
|
||||||
|
input_type: pointcloud
|
||||||
|
path: data/shapenet_psr
|
||||||
|
num_gt_points: 10000
|
||||||
|
num_offset: 5
|
||||||
|
pointcloud_n: 3000
|
||||||
|
pointcloud_noise: 0.005
|
||||||
|
pointcloud_outlier_ratio: 0.5
|
||||||
|
model:
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
psr_tanh: True
|
||||||
|
normal_normalize: False
|
||||||
|
predict_normal: True
|
||||||
|
predict_offset: True
|
||||||
|
c_dim: 32
|
||||||
|
s_offset: 0.001
|
||||||
|
encoder: local_pool_pointnet
|
||||||
|
encoder_kwargs:
|
||||||
|
hidden_dim: 32
|
||||||
|
plane_type: ['xz', 'xy', 'yz']
|
||||||
|
plane_resolution: 64
|
||||||
|
unet: True
|
||||||
|
unet_kwargs:
|
||||||
|
depth: 4
|
||||||
|
merge_mode: concat
|
||||||
|
start_filts: 32
|
||||||
|
decoder: simple_local
|
||||||
|
decoder_kwargs:
|
||||||
|
sample_mode: bilinear # bilinear / nearest
|
||||||
|
hidden_size: 32
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
lr: 5e-4
|
||||||
|
out_dir: out/shapenet/outlier_ours_3plane
|
||||||
|
w_psr: 1
|
||||||
|
model_selection_metric: psr_l2
|
||||||
|
print_every: 100
|
||||||
|
checkpoint_every: 200
|
||||||
|
validate_every: 5000
|
||||||
|
backup_every: 10000
|
||||||
|
total_epochs: 400000
|
||||||
|
visualize_every: 5000
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
n_workers: 8
|
||||||
|
n_workers_val: 0
|
||||||
|
generation:
|
||||||
|
exp_gt: False
|
||||||
|
exp_input: True
|
||||||
|
psr_resolution: 128
|
||||||
|
psr_sigma: 2
|
55
configs/learning_based/outlier/ours_3x.yaml
Normal file
55
configs/learning_based/outlier/ours_3x.yaml
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
data:
|
||||||
|
class: null
|
||||||
|
data_type: psr_full
|
||||||
|
input_type: pointcloud
|
||||||
|
path: data/shapenet_psr
|
||||||
|
num_gt_points: 10000
|
||||||
|
num_offset: 3
|
||||||
|
pointcloud_n: 3000
|
||||||
|
pointcloud_noise: 0.005
|
||||||
|
pointcloud_outlier_ratio: 0.5
|
||||||
|
model:
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
psr_tanh: True
|
||||||
|
normal_normalize: False
|
||||||
|
predict_normal: True
|
||||||
|
predict_offset: True
|
||||||
|
c_dim: 32
|
||||||
|
s_offset: 0.001
|
||||||
|
encoder: local_pool_pointnet
|
||||||
|
encoder_kwargs:
|
||||||
|
hidden_dim: 32
|
||||||
|
plane_type: 'grid'
|
||||||
|
grid_resolution: 32
|
||||||
|
unet3d: True
|
||||||
|
unet3d_kwargs:
|
||||||
|
num_levels: 3
|
||||||
|
f_maps: 32
|
||||||
|
in_channels: 32
|
||||||
|
out_channels: 32
|
||||||
|
decoder: simple_local
|
||||||
|
decoder_kwargs:
|
||||||
|
sample_mode: bilinear # bilinear / nearest
|
||||||
|
hidden_size: 32
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
lr: 5e-4
|
||||||
|
out_dir: out/shapenet/outlier_ours_3x
|
||||||
|
w_psr: 1
|
||||||
|
model_selection_metric: psr_l2
|
||||||
|
print_every: 100
|
||||||
|
checkpoint_every: 200
|
||||||
|
validate_every: 5000
|
||||||
|
backup_every: 10000
|
||||||
|
total_epochs: 400000
|
||||||
|
visualize_every: 5000
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
n_workers: 8
|
||||||
|
n_workers_val: 0
|
||||||
|
generation:
|
||||||
|
exp_gt: False
|
||||||
|
exp_input: True
|
||||||
|
psr_resolution: 128
|
||||||
|
psr_sigma: 2
|
5
configs/learning_based/outlier/ours_3x_pretrained.yaml
Normal file
5
configs/learning_based/outlier/ours_3x_pretrained.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
inherit_from: configs/learning_based/outlier/ours_1x/ours.yaml
|
||||||
|
generation:
|
||||||
|
generation_dir: generation_pretrained
|
||||||
|
test:
|
||||||
|
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_1x.pt
|
55
configs/learning_based/outlier/ours_5x.yaml
Normal file
55
configs/learning_based/outlier/ours_5x.yaml
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
data:
|
||||||
|
class: null
|
||||||
|
data_type: psr_full
|
||||||
|
input_type: pointcloud
|
||||||
|
path: data/shapenet_psr
|
||||||
|
num_gt_points: 10000
|
||||||
|
num_offset: 5
|
||||||
|
pointcloud_n: 3000
|
||||||
|
pointcloud_noise: 0.005
|
||||||
|
pointcloud_outlier_ratio: 0.5
|
||||||
|
model:
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
psr_tanh: True
|
||||||
|
normal_normalize: False
|
||||||
|
predict_normal: True
|
||||||
|
predict_offset: True
|
||||||
|
c_dim: 32
|
||||||
|
s_offset: 0.001
|
||||||
|
encoder: local_pool_pointnet
|
||||||
|
encoder_kwargs:
|
||||||
|
hidden_dim: 32
|
||||||
|
plane_type: 'grid'
|
||||||
|
grid_resolution: 32
|
||||||
|
unet3d: True
|
||||||
|
unet3d_kwargs:
|
||||||
|
num_levels: 3
|
||||||
|
f_maps: 32
|
||||||
|
in_channels: 32
|
||||||
|
out_channels: 32
|
||||||
|
decoder: simple_local
|
||||||
|
decoder_kwargs:
|
||||||
|
sample_mode: bilinear # bilinear / nearest
|
||||||
|
hidden_size: 32
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
lr: 5e-4
|
||||||
|
out_dir: out/shapenet/outlier_ours_5x
|
||||||
|
w_psr: 1
|
||||||
|
model_selection_metric: psr_l2
|
||||||
|
print_every: 100
|
||||||
|
checkpoint_every: 200
|
||||||
|
validate_every: 5000
|
||||||
|
backup_every: 10000
|
||||||
|
total_epochs: 400000
|
||||||
|
visualize_every: 5000
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
n_workers: 8
|
||||||
|
n_workers_val: 0
|
||||||
|
generation:
|
||||||
|
exp_gt: False
|
||||||
|
exp_input: True
|
||||||
|
psr_resolution: 128
|
||||||
|
psr_sigma: 2
|
5
configs/learning_based/outlier/ours_5x_pretrained.yaml
Normal file
5
configs/learning_based/outlier/ours_5x_pretrained.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
inherit_from: configs/learning_based/outlier/ours_5x/ours.yaml
|
||||||
|
generation:
|
||||||
|
generation_dir: generation_pretrained
|
||||||
|
test:
|
||||||
|
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_5x.pt
|
55
configs/learning_based/outlier/ours_7x.yaml
Normal file
55
configs/learning_based/outlier/ours_7x.yaml
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
data:
|
||||||
|
class: null
|
||||||
|
data_type: psr_full
|
||||||
|
input_type: pointcloud
|
||||||
|
path: data/shapenet_psr
|
||||||
|
num_gt_points: 10000
|
||||||
|
num_offset: 7
|
||||||
|
pointcloud_n: 3000
|
||||||
|
pointcloud_noise: 0.005
|
||||||
|
pointcloud_outlier_ratio: 0.5
|
||||||
|
model:
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
psr_tanh: True
|
||||||
|
normal_normalize: False
|
||||||
|
predict_normal: True
|
||||||
|
predict_offset: True
|
||||||
|
c_dim: 32
|
||||||
|
s_offset: 0.001
|
||||||
|
encoder: local_pool_pointnet
|
||||||
|
encoder_kwargs:
|
||||||
|
hidden_dim: 32
|
||||||
|
plane_type: 'grid'
|
||||||
|
grid_resolution: 32
|
||||||
|
unet3d: True
|
||||||
|
unet3d_kwargs:
|
||||||
|
num_levels: 3
|
||||||
|
f_maps: 32
|
||||||
|
in_channels: 32
|
||||||
|
out_channels: 32
|
||||||
|
decoder: simple_local
|
||||||
|
decoder_kwargs:
|
||||||
|
sample_mode: bilinear # bilinear / nearest
|
||||||
|
hidden_size: 32
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
lr: 5e-4
|
||||||
|
out_dir: out/shapenet/outlier_ours_7x
|
||||||
|
w_psr: 1
|
||||||
|
model_selection_metric: psr_l2
|
||||||
|
print_every: 100
|
||||||
|
checkpoint_every: 200
|
||||||
|
validate_every: 5000
|
||||||
|
backup_every: 10000
|
||||||
|
total_epochs: 400000
|
||||||
|
visualize_every: 5000
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
n_workers: 8
|
||||||
|
n_workers_val: 0
|
||||||
|
generation:
|
||||||
|
exp_gt: False
|
||||||
|
exp_input: True
|
||||||
|
psr_resolution: 128
|
||||||
|
psr_sigma: 2
|
5
configs/learning_based/outlier/ours_7x_pretrained.yaml
Normal file
5
configs/learning_based/outlier/ours_7x_pretrained.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
inherit_from: configs/learning_based/outlier/ours_7x/ours.yaml
|
||||||
|
generation:
|
||||||
|
generation_dir: generation_pretrained
|
||||||
|
test:
|
||||||
|
model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/models/ours_outlier_7x.pt
|
37
configs/optim_based/dfaust.yaml
Normal file
37
configs/optim_based/dfaust.yaml
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
data:
|
||||||
|
class: 'only_pcl'
|
||||||
|
data_type: point
|
||||||
|
data_path: 'data/dfaust/*.ply'
|
||||||
|
object_id: 0
|
||||||
|
num_points: 20000
|
||||||
|
model:
|
||||||
|
sphere_radius: 0.2
|
||||||
|
grid_res: 256 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
train:
|
||||||
|
schedule:
|
||||||
|
pcl:
|
||||||
|
initial: 1e-2
|
||||||
|
interval: 700
|
||||||
|
factor: 0.5
|
||||||
|
final: 1e-3
|
||||||
|
out_dir: out/dfaust
|
||||||
|
w_chamfer: 1
|
||||||
|
n_sup_point: 20000
|
||||||
|
batch_size: 1
|
||||||
|
n_grow_points: 2000
|
||||||
|
resample_every: 200
|
||||||
|
subsample_vertex: False
|
||||||
|
total_epochs: 1600
|
||||||
|
print_every: 10
|
||||||
|
visualize_every: 2
|
||||||
|
checkpoint_every: 500
|
||||||
|
save_every: 100
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
o3d_show: False
|
||||||
|
o3d_vis_pcl: True
|
||||||
|
o3d_window_size: 540
|
||||||
|
vis_rendering: False
|
||||||
|
vis_vert_color: False
|
||||||
|
n_workers: 0
|
38
configs/optim_based/dgp.yaml
Normal file
38
configs/optim_based/dgp.yaml
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
data:
|
||||||
|
class: 'only_pcl'
|
||||||
|
data_type: point
|
||||||
|
data_path: 'data/deep_geometric_prior_data/*.ply'
|
||||||
|
object_id: 0
|
||||||
|
num_points: 20000
|
||||||
|
model:
|
||||||
|
sphere_radius: 0.2
|
||||||
|
grid_res: 256 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
train:
|
||||||
|
schedule:
|
||||||
|
pcl:
|
||||||
|
initial: 1e-2
|
||||||
|
interval: 700
|
||||||
|
factor: 0.5
|
||||||
|
final: 1e-3
|
||||||
|
out_dir: out/dgp
|
||||||
|
w_reg_point: 0
|
||||||
|
w_chamfer: 1
|
||||||
|
n_sup_point: 20000
|
||||||
|
batch_size: 1
|
||||||
|
n_grow_points: 2000
|
||||||
|
resample_every: 200
|
||||||
|
subsample_vertex: False
|
||||||
|
total_epochs: 1600
|
||||||
|
print_every: 10
|
||||||
|
visualize_every: 2
|
||||||
|
checkpoint_every: 500
|
||||||
|
save_every: 100
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
o3d_show: False
|
||||||
|
o3d_vis_pcl: True
|
||||||
|
o3d_window_size: 540
|
||||||
|
vis_rendering: False
|
||||||
|
vis_vert_color: False
|
||||||
|
n_workers: 0
|
37
configs/optim_based/teaser.yaml
Normal file
37
configs/optim_based/teaser.yaml
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
data:
|
||||||
|
class: 'only_pcl'
|
||||||
|
data_type: point
|
||||||
|
data_path: 'data/demo/wheel.obj'
|
||||||
|
object_id: 0
|
||||||
|
num_points: 20000
|
||||||
|
model:
|
||||||
|
sphere_radius: 0.2
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
train:
|
||||||
|
schedule:
|
||||||
|
pcl:
|
||||||
|
initial: 1e-2
|
||||||
|
interval: 700
|
||||||
|
factor: 0.5
|
||||||
|
final: 1e-3
|
||||||
|
out_dir: out/demo_optim
|
||||||
|
w_chamfer: 1
|
||||||
|
n_sup_point: 20000
|
||||||
|
batch_size: 1
|
||||||
|
n_grow_points: 2000
|
||||||
|
resample_every: 200
|
||||||
|
subsample_vertex: False
|
||||||
|
total_epochs: 1600
|
||||||
|
print_every: 10
|
||||||
|
visualize_every: 2
|
||||||
|
checkpoint_every: 500
|
||||||
|
save_every: 100
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
o3d_show: False
|
||||||
|
o3d_vis_pcl: True
|
||||||
|
o3d_window_size: 540
|
||||||
|
vis_rendering: False
|
||||||
|
vis_vert_color: False
|
||||||
|
n_workers: 0
|
39
configs/optim_based/thingi.yaml
Normal file
39
configs/optim_based/thingi.yaml
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
data:
|
||||||
|
class: 'only_pcl'
|
||||||
|
data_type: point
|
||||||
|
data_path: 'data/thingi/*.ply'
|
||||||
|
object_id: 0
|
||||||
|
num_points: 20000
|
||||||
|
model:
|
||||||
|
sphere_radius: 0.2
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
train:
|
||||||
|
# lr_pcl: 2e-2
|
||||||
|
schedule:
|
||||||
|
pcl:
|
||||||
|
initial: 1e-2
|
||||||
|
interval: 700
|
||||||
|
factor: 0.5
|
||||||
|
final: 1e-3
|
||||||
|
out_dir: out/thingi
|
||||||
|
w_reg_point: 0
|
||||||
|
w_chamfer: 1
|
||||||
|
n_sup_point: 20000
|
||||||
|
batch_size: 1
|
||||||
|
n_grow_points: 2000
|
||||||
|
resample_every: 200
|
||||||
|
subsample_vertex: False
|
||||||
|
total_epochs: 1600
|
||||||
|
print_every: 10
|
||||||
|
visualize_every: 2
|
||||||
|
checkpoint_every: 500
|
||||||
|
save_every: 100
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
o3d_show: False
|
||||||
|
o3d_vis_pcl: True
|
||||||
|
o3d_window_size: 540
|
||||||
|
vis_rendering: False
|
||||||
|
vis_vert_color: False
|
||||||
|
n_workers: 0
|
39
configs/optim_based/thingi_noisy.yaml
Normal file
39
configs/optim_based/thingi_noisy.yaml
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
data:
|
||||||
|
class: 'only_pcl'
|
||||||
|
data_type: point
|
||||||
|
data_path: 'data/thingi_noisy/*.ply'
|
||||||
|
object_id: 0
|
||||||
|
num_points: 20000
|
||||||
|
model:
|
||||||
|
sphere_radius: 0.2
|
||||||
|
grid_res: 128 # poisson grid resolution
|
||||||
|
psr_sigma: 2
|
||||||
|
train:
|
||||||
|
# lr_pcl: 2e-2
|
||||||
|
schedule:
|
||||||
|
pcl:
|
||||||
|
initial: 1e-2
|
||||||
|
interval: 700
|
||||||
|
factor: 0.5
|
||||||
|
final: 1e-3
|
||||||
|
out_dir: out/thingi_noisy
|
||||||
|
w_reg_point: 0
|
||||||
|
w_chamfer: 1
|
||||||
|
n_sup_point: 20000
|
||||||
|
batch_size: 1
|
||||||
|
n_grow_points: 2000
|
||||||
|
resample_every: 200
|
||||||
|
subsample_vertex: False
|
||||||
|
total_epochs: 1600
|
||||||
|
print_every: 10
|
||||||
|
visualize_every: 2
|
||||||
|
checkpoint_every: 500
|
||||||
|
save_every: 100
|
||||||
|
exp_pcl: True
|
||||||
|
exp_mesh: True
|
||||||
|
o3d_show: False
|
||||||
|
o3d_vis_pcl: True
|
||||||
|
o3d_window_size: 540
|
||||||
|
vis_rendering: False
|
||||||
|
vis_vert_color: False
|
||||||
|
n_workers: 0
|
29
environment.yaml
Normal file
29
environment.yaml
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
name: sap
|
||||||
|
channels:
|
||||||
|
- anaconda
|
||||||
|
- conda-forge
|
||||||
|
- pytorch
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- python=3.8
|
||||||
|
- pytorch=1.9.0
|
||||||
|
- torchvision=0.10.0
|
||||||
|
- cudatoolkit=10.2
|
||||||
|
- numpy=1.18.1
|
||||||
|
- matplotlib=3.4.3
|
||||||
|
- pyyaml=5.3.1
|
||||||
|
- scipy=1.4.1
|
||||||
|
- tqdm=4.54.0
|
||||||
|
- trimesh=3.8.14
|
||||||
|
- igl=2.2.1
|
||||||
|
- tensorboard=2.6.0
|
||||||
|
- pip
|
||||||
|
- pip:
|
||||||
|
- plyfile==0.7
|
||||||
|
- open3d>=0.11.1
|
||||||
|
- scikit-image>=0.18.0
|
||||||
|
- python-mnist==0.7
|
||||||
|
- opencv-python>=4.4
|
||||||
|
- av==8.0.3
|
||||||
|
- pykdtree==1.3.4
|
||||||
|
- ipdb==0.13.7
|
155
eval_meshes.py
Normal file
155
eval_meshes.py
Normal file
|
@ -0,0 +1,155 @@
|
||||||
|
import torch
|
||||||
|
import trimesh
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import numpy as np; np.set_printoptions(precision=4)
|
||||||
|
import shutil, argparse, time, os
|
||||||
|
import pandas as pd
|
||||||
|
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
|
||||||
|
from src.training import Trainer
|
||||||
|
from src.model import Encode2Points
|
||||||
|
from src.data import PointCloudField, IndexField, Shapes3dDataset
|
||||||
|
from src.utils import load_config, load_pointcloud
|
||||||
|
from src.eval import MeshEvaluator
|
||||||
|
from tqdm import tqdm
|
||||||
|
from pdb import set_trace as st
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||||
|
parser.add_argument('config', type=str, help='Path to config file.')
|
||||||
|
parser.add_argument('--no_cuda', action='store_true', default=False,
|
||||||
|
help='disables CUDA training')
|
||||||
|
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
||||||
|
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
cfg = load_config(args.config, 'configs/default.yaml')
|
||||||
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
|
data_type = cfg['data']['data_type']
|
||||||
|
# Shorthands
|
||||||
|
out_dir = cfg['train']['out_dir']
|
||||||
|
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
|
||||||
|
|
||||||
|
if cfg['generation'].get('iter', 0)!=0:
|
||||||
|
generation_dir += '_%04d'%cfg['generation']['iter']
|
||||||
|
elif args.iter is not None:
|
||||||
|
generation_dir += '_%04d'%args.iter
|
||||||
|
|
||||||
|
print('Evaluate meshes under %s'%generation_dir)
|
||||||
|
|
||||||
|
out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl')
|
||||||
|
out_file_class = os.path.join(generation_dir, 'eval_meshes.csv')
|
||||||
|
|
||||||
|
# PYTORCH VERSION > 1.0.0
|
||||||
|
assert(float(torch.__version__.split('.')[-3]) > 0)
|
||||||
|
|
||||||
|
pointcloud_field = PointCloudField(cfg['data']['pointcloud_file'])
|
||||||
|
fields = {
|
||||||
|
'pointcloud': pointcloud_field,
|
||||||
|
'idx': IndexField(),
|
||||||
|
}
|
||||||
|
|
||||||
|
print('Test split: ', cfg['data']['test_split'])
|
||||||
|
|
||||||
|
dataset_folder = cfg['data']['path']
|
||||||
|
dataset = Shapes3dDataset(
|
||||||
|
dataset_folder, fields,
|
||||||
|
cfg['data']['test_split'],
|
||||||
|
categories=cfg['data']['class'], cfg=cfg)
|
||||||
|
|
||||||
|
# Loader
|
||||||
|
test_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset, batch_size=1, num_workers=0, shuffle=False)
|
||||||
|
|
||||||
|
# Evaluator
|
||||||
|
evaluator = MeshEvaluator(n_points=100000)
|
||||||
|
|
||||||
|
eval_dicts = []
|
||||||
|
print('Evaluating meshes...')
|
||||||
|
for it, data in enumerate(tqdm(test_loader)):
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
print('Invalid data.')
|
||||||
|
continue
|
||||||
|
|
||||||
|
mesh_dir = os.path.join(generation_dir, 'meshes')
|
||||||
|
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
|
||||||
|
|
||||||
|
|
||||||
|
# Get index etc.
|
||||||
|
idx = data['idx'].item()
|
||||||
|
try:
|
||||||
|
model_dict = dataset.get_model_dict(idx)
|
||||||
|
except AttributeError:
|
||||||
|
model_dict = {'model': str(idx), 'category': 'n/a'}
|
||||||
|
|
||||||
|
modelname = model_dict['model']
|
||||||
|
category_id = model_dict['category']
|
||||||
|
|
||||||
|
try:
|
||||||
|
category_name = dataset.metadata[category_id].get('name', 'n/a')
|
||||||
|
except AttributeError:
|
||||||
|
category_name = 'n/a'
|
||||||
|
|
||||||
|
if category_id != 'n/a':
|
||||||
|
mesh_dir = os.path.join(mesh_dir, category_id)
|
||||||
|
pointcloud_dir = os.path.join(pointcloud_dir, category_id)
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
pointcloud_tgt = data['pointcloud'].squeeze(0).numpy()
|
||||||
|
normals_tgt = data['pointcloud.normals'].squeeze(0).numpy()
|
||||||
|
|
||||||
|
|
||||||
|
eval_dict = {
|
||||||
|
'idx': idx,
|
||||||
|
'class id': category_id,
|
||||||
|
'class name': category_name,
|
||||||
|
'modelname':modelname,
|
||||||
|
}
|
||||||
|
eval_dicts.append(eval_dict)
|
||||||
|
|
||||||
|
# Evaluate mesh
|
||||||
|
if cfg['test']['eval_mesh']:
|
||||||
|
mesh_file = os.path.join(mesh_dir, '%s.off' % modelname)
|
||||||
|
|
||||||
|
if os.path.exists(mesh_file):
|
||||||
|
mesh = trimesh.load(mesh_file, process=False)
|
||||||
|
eval_dict_mesh = evaluator.eval_mesh(
|
||||||
|
mesh, pointcloud_tgt, normals_tgt)
|
||||||
|
for k, v in eval_dict_mesh.items():
|
||||||
|
eval_dict[k + ' (mesh)'] = v
|
||||||
|
else:
|
||||||
|
print('Warning: mesh does not exist: %s' % mesh_file)
|
||||||
|
|
||||||
|
# Evaluate point cloud
|
||||||
|
if cfg['test']['eval_pointcloud']:
|
||||||
|
pointcloud_file = os.path.join(
|
||||||
|
pointcloud_dir, '%s.ply' % modelname)
|
||||||
|
|
||||||
|
if os.path.exists(pointcloud_file):
|
||||||
|
pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
|
||||||
|
eval_dict_pcl = evaluator.eval_pointcloud(
|
||||||
|
pointcloud, pointcloud_tgt)
|
||||||
|
for k, v in eval_dict_pcl.items():
|
||||||
|
eval_dict[k + ' (pcl)'] = v
|
||||||
|
else:
|
||||||
|
print('Warning: pointcloud does not exist: %s'
|
||||||
|
% pointcloud_file)
|
||||||
|
|
||||||
|
|
||||||
|
# Create pandas dataframe and save
|
||||||
|
eval_df = pd.DataFrame(eval_dicts)
|
||||||
|
eval_df.set_index(['idx'], inplace=True)
|
||||||
|
eval_df.to_pickle(out_file)
|
||||||
|
|
||||||
|
# Create CSV file with main statistics
|
||||||
|
eval_df_class = eval_df.groupby(by=['class name']).mean()
|
||||||
|
eval_df_class.loc['mean'] = eval_df_class.mean()
|
||||||
|
eval_df_class.to_csv(out_file_class)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(eval_df_class)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
217
generate.py
Normal file
217
generate.py
Normal file
|
@ -0,0 +1,217 @@
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import numpy as np; np.set_printoptions(precision=4)
|
||||||
|
import shutil, argparse, time, os
|
||||||
|
import pandas as pd
|
||||||
|
from collections import defaultdict
|
||||||
|
from src import config
|
||||||
|
from src.utils import mc_from_psr, export_mesh, export_pointcloud
|
||||||
|
from src.dpsr import DPSR
|
||||||
|
from src.training import Trainer
|
||||||
|
from src.model import Encode2Points
|
||||||
|
from src.utils import load_config, load_model_manual, scale2onet, is_url, load_url
|
||||||
|
from tqdm import tqdm
|
||||||
|
from pdb import set_trace as st
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||||
|
parser.add_argument('config', type=str, help='Path to config file.')
|
||||||
|
parser.add_argument('--no_cuda', action='store_true', default=False,
|
||||||
|
help='disables CUDA training')
|
||||||
|
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
||||||
|
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
cfg = load_config(args.config, 'configs/default.yaml')
|
||||||
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
|
data_type = cfg['data']['data_type']
|
||||||
|
input_type = cfg['data']['input_type']
|
||||||
|
vis_n_outputs = cfg['generation']['vis_n_outputs']
|
||||||
|
if vis_n_outputs is None:
|
||||||
|
vis_n_outputs = -1
|
||||||
|
# Shorthands
|
||||||
|
out_dir = cfg['train']['out_dir']
|
||||||
|
if not out_dir:
|
||||||
|
os.makedirs(out_dir)
|
||||||
|
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
|
||||||
|
out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
|
||||||
|
out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')
|
||||||
|
|
||||||
|
# PYTORCH VERSION > 1.0.0
|
||||||
|
assert(float(torch.__version__.split('.')[-3]) > 0)
|
||||||
|
|
||||||
|
dataset = config.get_dataset('test', cfg, return_idx=True)
|
||||||
|
test_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset, batch_size=1, num_workers=0, shuffle=False)
|
||||||
|
|
||||||
|
model = Encode2Points(cfg).to(device)
|
||||||
|
|
||||||
|
# load model
|
||||||
|
try:
|
||||||
|
if is_url(cfg['test']['model_file']):
|
||||||
|
state_dict = load_url(cfg['test']['model_file'])
|
||||||
|
elif cfg['generation'].get('iter', 0)!=0:
|
||||||
|
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% cfg['generation']['iter']))
|
||||||
|
generation_dir += '_%04d'%cfg['generation']['iter']
|
||||||
|
elif args.iter is not None:
|
||||||
|
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% args.iter))
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(os.path.join(out_dir, 'model_best.pt'))
|
||||||
|
|
||||||
|
load_model_manual(state_dict['state_dict'], model)
|
||||||
|
|
||||||
|
except:
|
||||||
|
print('Model loading error. Exiting.')
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
# Generator
|
||||||
|
generator = config.get_generator(model, cfg, device=device)
|
||||||
|
|
||||||
|
# Determine what to generate
|
||||||
|
generate_mesh = cfg['generation']['generate_mesh']
|
||||||
|
generate_pointcloud = cfg['generation']['generate_pointcloud']
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
time_dicts = []
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
model.eval()
|
||||||
|
dpsr = DPSR(res=(cfg['generation']['psr_resolution'],
|
||||||
|
cfg['generation']['psr_resolution'],
|
||||||
|
cfg['generation']['psr_resolution']),
|
||||||
|
sig= cfg['generation']['psr_sigma']).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Count how many models already created
|
||||||
|
model_counter = defaultdict(int)
|
||||||
|
|
||||||
|
print('Generating...')
|
||||||
|
for it, data in enumerate(tqdm(test_loader)):
|
||||||
|
|
||||||
|
# Output folders
|
||||||
|
mesh_dir = os.path.join(generation_dir, 'meshes')
|
||||||
|
in_dir = os.path.join(generation_dir, 'input')
|
||||||
|
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
|
||||||
|
generation_vis_dir = os.path.join(generation_dir, 'vis', )
|
||||||
|
|
||||||
|
# Get index etc.
|
||||||
|
idx = data['idx'].item()
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_dict = dataset.get_model_dict(idx)
|
||||||
|
except AttributeError:
|
||||||
|
model_dict = {'model': str(idx), 'category': 'n/a'}
|
||||||
|
|
||||||
|
modelname = model_dict['model']
|
||||||
|
category_id = model_dict['category']
|
||||||
|
|
||||||
|
try:
|
||||||
|
category_name = dataset.metadata[category_id].get('name', 'n/a')
|
||||||
|
except AttributeError:
|
||||||
|
category_name = 'n/a'
|
||||||
|
|
||||||
|
if category_id != 'n/a':
|
||||||
|
mesh_dir = os.path.join(mesh_dir, str(category_id))
|
||||||
|
pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
|
||||||
|
in_dir = os.path.join(in_dir, str(category_id))
|
||||||
|
|
||||||
|
folder_name = str(category_id)
|
||||||
|
if category_name != 'n/a':
|
||||||
|
folder_name = str(folder_name) + '_' + category_name.split(',')[0]
|
||||||
|
|
||||||
|
generation_vis_dir = os.path.join(generation_vis_dir, folder_name)
|
||||||
|
|
||||||
|
# Create directories if necessary
|
||||||
|
if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir):
|
||||||
|
os.makedirs(generation_vis_dir)
|
||||||
|
|
||||||
|
if generate_mesh and not os.path.exists(mesh_dir):
|
||||||
|
os.makedirs(mesh_dir)
|
||||||
|
|
||||||
|
if generate_pointcloud and not os.path.exists(pointcloud_dir):
|
||||||
|
os.makedirs(pointcloud_dir)
|
||||||
|
|
||||||
|
if not os.path.exists(in_dir):
|
||||||
|
os.makedirs(in_dir)
|
||||||
|
|
||||||
|
# Timing dict
|
||||||
|
time_dict = {
|
||||||
|
'idx': idx,
|
||||||
|
'class id': category_id,
|
||||||
|
'class name': category_name,
|
||||||
|
'modelname':modelname,
|
||||||
|
}
|
||||||
|
time_dicts.append(time_dict)
|
||||||
|
|
||||||
|
# Generate outputs
|
||||||
|
out_file_dict = {}
|
||||||
|
|
||||||
|
if generate_mesh:
|
||||||
|
#! deploy the generator to a separate class
|
||||||
|
out = generator.generate_mesh(data)
|
||||||
|
|
||||||
|
v, f, points, normals, stats_dict = out
|
||||||
|
time_dict.update(stats_dict)
|
||||||
|
|
||||||
|
# Write output
|
||||||
|
mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname)
|
||||||
|
export_mesh(mesh_out_file, scale2onet(v), f)
|
||||||
|
out_file_dict['mesh'] = mesh_out_file
|
||||||
|
|
||||||
|
if generate_pointcloud:
|
||||||
|
pointcloud_out_file = os.path.join(
|
||||||
|
pointcloud_dir, '%s.ply' % modelname)
|
||||||
|
export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
|
||||||
|
out_file_dict['pointcloud'] = pointcloud_out_file
|
||||||
|
|
||||||
|
if cfg['generation']['copy_input']:
|
||||||
|
inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
|
||||||
|
p = data.get('inputs').to(device)
|
||||||
|
export_pointcloud(inputs_path, scale2onet(p))
|
||||||
|
out_file_dict['in'] = inputs_path
|
||||||
|
|
||||||
|
# Copy to visualization directory for first vis_n_output samples
|
||||||
|
c_it = model_counter[category_id]
|
||||||
|
if c_it < vis_n_outputs:
|
||||||
|
# Save output files
|
||||||
|
img_name = '%02d.off' % c_it
|
||||||
|
for k, filepath in out_file_dict.items():
|
||||||
|
ext = os.path.splitext(filepath)[1]
|
||||||
|
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
|
||||||
|
% (c_it, k, ext))
|
||||||
|
shutil.copyfile(filepath, out_file)
|
||||||
|
|
||||||
|
# Also generate oracle meshes
|
||||||
|
if cfg['generation']['exp_oracle']:
|
||||||
|
points_gt = data.get('gt_points').to(device)
|
||||||
|
normals_gt = data.get('gt_points.normals').to(device)
|
||||||
|
psr_gt = dpsr(points_gt, normals_gt)
|
||||||
|
v, f, _ = mc_from_psr(psr_gt,
|
||||||
|
zero_level=cfg['data']['zero_level'])
|
||||||
|
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
|
||||||
|
% (c_it, 'mesh_oracle', '.off'))
|
||||||
|
export_mesh(out_file, scale2onet(v), f)
|
||||||
|
|
||||||
|
model_counter[category_id] += 1
|
||||||
|
|
||||||
|
|
||||||
|
# Create pandas dataframe and save
|
||||||
|
time_df = pd.DataFrame(time_dicts)
|
||||||
|
time_df.set_index(['idx'], inplace=True)
|
||||||
|
time_df.to_pickle(out_time_file)
|
||||||
|
|
||||||
|
# Create pickle files with main statistics
|
||||||
|
time_df_class = time_df.groupby(by=['class name']).mean()
|
||||||
|
time_df_class.loc['mean'] = time_df_class.mean()
|
||||||
|
time_df_class.to_pickle(out_time_file_class)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print('Timings [s]:')
|
||||||
|
print(time_df_class)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
BIN
media/results_large_noise.gif
Normal file
BIN
media/results_large_noise.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.7 MiB |
BIN
media/results_outliers.gif
Normal file
BIN
media/results_outliers.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.7 MiB |
315
optim.py
Normal file
315
optim.py
Normal file
|
@ -0,0 +1,315 @@
|
||||||
|
import torch
|
||||||
|
import trimesh
|
||||||
|
import shutil, argparse, time, os, glob
|
||||||
|
|
||||||
|
import numpy as np; np.set_printoptions(precision=4)
|
||||||
|
import open3d as o3d
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
from torchvision.io import write_video
|
||||||
|
|
||||||
|
from src.optimization import Trainer
|
||||||
|
from src.utils import load_config, update_config, initialize_logger, \
|
||||||
|
get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\
|
||||||
|
update_optimizer, export_pointcloud
|
||||||
|
from skimage import measure
|
||||||
|
from plyfile import PlyData
|
||||||
|
from pytorch3d.ops import sample_points_from_meshes
|
||||||
|
from pytorch3d.io import load_objs_as_meshes
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||||
|
parser.add_argument('config', type=str, help='Path to config file.')
|
||||||
|
parser.add_argument('--no_cuda', action='store_true', default=False,
|
||||||
|
help='disables CUDA training')
|
||||||
|
parser.add_argument('--seed', type=int, default=1457, metavar='S',
|
||||||
|
help='random seed')
|
||||||
|
|
||||||
|
args, unknown = parser.parse_known_args()
|
||||||
|
cfg = load_config(args.config, 'configs/default.yaml')
|
||||||
|
cfg = update_config(cfg, unknown)
|
||||||
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
|
data_type = cfg['data']['data_type']
|
||||||
|
data_class = cfg['data']['class']
|
||||||
|
|
||||||
|
print(cfg['train']['out_dir'])
|
||||||
|
|
||||||
|
# PYTORCH VERSION > 1.0.0
|
||||||
|
assert(float(torch.__version__.split('.')[-3]) > 0)
|
||||||
|
|
||||||
|
# boiler-plate
|
||||||
|
if cfg['train']['timestamp']:
|
||||||
|
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||||
|
logger = initialize_logger(cfg)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
shutil.copyfile(args.config,
|
||||||
|
os.path.join(cfg['train']['out_dir'], 'config.yaml'))
|
||||||
|
|
||||||
|
# tensorboardX writer
|
||||||
|
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
|
||||||
|
if not os.path.exists(tblogdir):
|
||||||
|
os.makedirs(tblogdir)
|
||||||
|
writer = SummaryWriter(log_dir=tblogdir)
|
||||||
|
|
||||||
|
# initialize o3d visualizer
|
||||||
|
vis = None
|
||||||
|
if cfg['train']['o3d_show']:
|
||||||
|
vis = o3d.visualization.Visualizer()
|
||||||
|
vis.create_window(width=cfg['train']['o3d_window_size'],
|
||||||
|
height=cfg['train']['o3d_window_size'])
|
||||||
|
|
||||||
|
# initialize dataset
|
||||||
|
if data_type == 'point':
|
||||||
|
if cfg['data']['object_id'] != -1:
|
||||||
|
data_paths = sorted(glob.glob(cfg['data']['data_path']))
|
||||||
|
data_path = data_paths[cfg['data']['object_id']]
|
||||||
|
print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths)))
|
||||||
|
else:
|
||||||
|
data_path = cfg['data']['data_path']
|
||||||
|
print('Data loaded')
|
||||||
|
ext = data_path.split('.')[-1]
|
||||||
|
if ext == 'obj': # have GT mesh
|
||||||
|
mesh = load_objs_as_meshes([data_path], device=device)
|
||||||
|
# scale the mesh into unit cube
|
||||||
|
verts = mesh.verts_packed()
|
||||||
|
N = verts.shape[0]
|
||||||
|
center = verts.mean(0)
|
||||||
|
mesh.offset_verts_(-center.expand(N, 3))
|
||||||
|
scale = max((verts - center).abs().max(0)[0])
|
||||||
|
mesh.scale_verts_((1.0 / float(scale)))
|
||||||
|
# important for our DPSR to have the range in [0, 1), not reaching 1
|
||||||
|
mesh.scale_verts_(0.9)
|
||||||
|
|
||||||
|
target_pts, target_normals = sample_points_from_meshes(mesh,
|
||||||
|
num_samples=200000, return_normals=True)
|
||||||
|
elif ext == 'ply': # only have the point cloud
|
||||||
|
plydata = PlyData.read(data_path)
|
||||||
|
vertices = np.stack([plydata['vertex']['x'],
|
||||||
|
plydata['vertex']['y'],
|
||||||
|
plydata['vertex']['z']], axis=1)
|
||||||
|
normals = np.stack([plydata['vertex']['nx'],
|
||||||
|
plydata['vertex']['ny'],
|
||||||
|
plydata['vertex']['nz']], axis=1)
|
||||||
|
N = vertices.shape[0]
|
||||||
|
center = vertices.mean(0)
|
||||||
|
scale = np.max(np.max(np.abs(vertices - center), axis=0))
|
||||||
|
vertices -= center
|
||||||
|
vertices /= scale
|
||||||
|
vertices *= 0.9
|
||||||
|
|
||||||
|
target_pts = torch.tensor(vertices, device=device)[None].float()
|
||||||
|
target_normals = torch.tensor(normals, device=device)[None].float()
|
||||||
|
mesh = None # no GT mesh
|
||||||
|
|
||||||
|
if not torch.is_tensor(center):
|
||||||
|
center = torch.from_numpy(center)
|
||||||
|
if not torch.is_tensor(scale):
|
||||||
|
scale = torch.from_numpy(np.array([scale]))
|
||||||
|
|
||||||
|
data = {'target_points': target_pts,
|
||||||
|
'target_normals': target_normals, # normals are never used
|
||||||
|
'gt_mesh': mesh}
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# save the input point cloud
|
||||||
|
if 'target_points' in data.keys():
|
||||||
|
outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply')
|
||||||
|
if 'target_normals' in data.keys():
|
||||||
|
export_pointcloud(outdir_pcl, data['target_points'], data['target_normals'])
|
||||||
|
else:
|
||||||
|
export_pointcloud(outdir_pcl, data['target_points'])
|
||||||
|
|
||||||
|
# save oracle PSR mesh (mesh from our PSR using GT point+normals)
|
||||||
|
if data.get('gt_mesh') is not None:
|
||||||
|
gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0)
|
||||||
|
pts_gt, norms_gt = sample_points_from_meshes(data['gt_mesh'],
|
||||||
|
num_samples=500000, return_normals=True)
|
||||||
|
pts_gt = (pts_gt + 1) / 2
|
||||||
|
from src.dpsr import DPSR
|
||||||
|
dpsr_tmp = DPSR(res=(cfg['model']['grid_res'],
|
||||||
|
cfg['model']['grid_res'],
|
||||||
|
cfg['model']['grid_res']),
|
||||||
|
sig=cfg['model']['psr_sigma']).to(device)
|
||||||
|
target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
|
||||||
|
target = torch.tanh(target)
|
||||||
|
s = target.shape[-1] # size of psr_grid
|
||||||
|
psr_grid_numpy = target.squeeze().detach().cpu().numpy()
|
||||||
|
verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
|
||||||
|
verts = verts / s * 2. - 1 # [-1, 1]
|
||||||
|
mesh = o3d.geometry.TriangleMesh()
|
||||||
|
mesh.vertices = o3d.utility.Vector3dVector(verts)
|
||||||
|
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
||||||
|
outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply')
|
||||||
|
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
|
||||||
|
|
||||||
|
# initialize the source point cloud given an input mesh
|
||||||
|
if 'input_mesh' in cfg['train'].keys() and \
|
||||||
|
os.path.isfile(cfg['train']['input_mesh']):
|
||||||
|
if cfg['train']['input_mesh'].split('/')[-2] == 'mesh':
|
||||||
|
mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh'])
|
||||||
|
verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
|
||||||
|
faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
|
||||||
|
mesh = Meshes(verts=verts, faces=faces)
|
||||||
|
points, normals = sample_points_from_meshes(mesh,
|
||||||
|
num_samples=cfg['data']['num_points'], return_normals=True)
|
||||||
|
# mesh is saved in the original scale of the gt
|
||||||
|
points -= center.float().to(device)
|
||||||
|
points /= scale.float().to(device)
|
||||||
|
points *= 0.9
|
||||||
|
# make sure the points are within the range of [0, 1)
|
||||||
|
points = points / 2. + 0.5
|
||||||
|
else:
|
||||||
|
# directly initialize from a point cloud
|
||||||
|
pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh'])
|
||||||
|
points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
|
||||||
|
normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
|
||||||
|
points -= center.float().to(device)
|
||||||
|
points /= scale.float().to(device)
|
||||||
|
points *= 0.9
|
||||||
|
points = points / 2. + 0.5
|
||||||
|
else: #! initialize our source point cloud from a sphere
|
||||||
|
sphere_radius = cfg['model']['sphere_radius']
|
||||||
|
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius,
|
||||||
|
count=[256,256])
|
||||||
|
points, idx = sphere_mesh.sample(cfg['data']['num_points'],
|
||||||
|
return_index=True)
|
||||||
|
points += 0.5 # make sure the points are within the range of [0, 1)
|
||||||
|
normals = sphere_mesh.face_normals[idx]
|
||||||
|
points = torch.from_numpy(points).unsqueeze(0).to(device)
|
||||||
|
normals = torch.from_numpy(normals).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
points = torch.log(points/(1-points)) # inverse sigmoid
|
||||||
|
inputs = torch.cat([points, normals], axis=-1).float()
|
||||||
|
inputs.requires_grad = True
|
||||||
|
|
||||||
|
model = None # no network
|
||||||
|
|
||||||
|
# initialize optimizer
|
||||||
|
cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl']
|
||||||
|
print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial'])
|
||||||
|
if 'schedule' in cfg['train']:
|
||||||
|
lr_schedules = get_learning_rate_schedules(cfg['train']['schedule'])
|
||||||
|
else:
|
||||||
|
lr_schedules = None
|
||||||
|
|
||||||
|
optimizer = update_optimizer(inputs, cfg,
|
||||||
|
epoch=0, model=model, schedule=lr_schedules)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# load model
|
||||||
|
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
||||||
|
if ('pcl' in state_dict.keys()) & (state_dict['pcl'] is not None):
|
||||||
|
inputs = state_dict['pcl'].to(device)
|
||||||
|
inputs.requires_grad = True
|
||||||
|
|
||||||
|
optimizer = update_optimizer(inputs, cfg,
|
||||||
|
epoch=state_dict.get('epoch'), schedule=lr_schedules)
|
||||||
|
|
||||||
|
out = "Load model from epoch %d" % state_dict.get('epoch', 0)
|
||||||
|
print(out)
|
||||||
|
logger.info(out)
|
||||||
|
except:
|
||||||
|
state_dict = dict()
|
||||||
|
|
||||||
|
start_epoch = state_dict.get('epoch', -1)
|
||||||
|
|
||||||
|
trainer = Trainer(cfg, optimizer, device=device)
|
||||||
|
runtime = {}
|
||||||
|
runtime['all'] = AverageMeter()
|
||||||
|
|
||||||
|
# training loop
|
||||||
|
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
|
||||||
|
|
||||||
|
# schedule the learning rate
|
||||||
|
if (epoch>0) & (lr_schedules is not None):
|
||||||
|
if (epoch % lr_schedules[0].interval == 0):
|
||||||
|
adjust_learning_rate(lr_schedules, optimizer, epoch)
|
||||||
|
if len(lr_schedules) >1:
|
||||||
|
print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch,
|
||||||
|
lr_schedules[0].get_learning_rate(epoch),
|
||||||
|
lr_schedules[1].get_learning_rate(epoch)))
|
||||||
|
else:
|
||||||
|
print('[epoch {}] adjust pcl_lr to: {}'.format(epoch,
|
||||||
|
lr_schedules[0].get_learning_rate(epoch)))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
loss, loss_each = trainer.train_step(data, inputs, model, epoch)
|
||||||
|
runtime['all'].update(time.time() - start)
|
||||||
|
|
||||||
|
if epoch % cfg['train']['print_every'] == 0:
|
||||||
|
log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss)
|
||||||
|
if loss_each is not None:
|
||||||
|
for k, l in loss_each.items():
|
||||||
|
if l.item() != 0.:
|
||||||
|
log_text += (' loss_%s=%.5f') % (k, l.item())
|
||||||
|
|
||||||
|
log_text += (' time=%.3f / %.3f') % (runtime['all'].val,
|
||||||
|
runtime['all'].sum)
|
||||||
|
logger.info(log_text)
|
||||||
|
print(log_text)
|
||||||
|
|
||||||
|
# visualize point clouds and meshes
|
||||||
|
if (epoch % cfg['train']['visualize_every'] == 0) & (vis is not None):
|
||||||
|
trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)
|
||||||
|
|
||||||
|
# save outputs
|
||||||
|
if epoch % cfg['train']['save_every'] == 0:
|
||||||
|
trainer.save_mesh_pointclouds(inputs, epoch,
|
||||||
|
center.cpu().numpy(),
|
||||||
|
scale.cpu().numpy()*(1/0.9))
|
||||||
|
|
||||||
|
# save checkpoints
|
||||||
|
if (epoch > 0) & (epoch % cfg['train']['checkpoint_every'] == 0):
|
||||||
|
state = {'epoch': epoch}
|
||||||
|
pcl = None
|
||||||
|
if isinstance(inputs, torch.Tensor):
|
||||||
|
state['pcl'] = inputs.detach().cpu()
|
||||||
|
|
||||||
|
torch.save(state, os.path.join(cfg['train']['dir_model'],
|
||||||
|
'%04d' % epoch + '.pt'))
|
||||||
|
print("Save new model at epoch %d" % epoch)
|
||||||
|
logger.info("Save new model at epoch %d" % epoch)
|
||||||
|
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
||||||
|
|
||||||
|
# resample and gradually add new points to the source pcl
|
||||||
|
if (epoch > 0) & \
|
||||||
|
(cfg['train']['resample_every']!=0) & \
|
||||||
|
(epoch % cfg['train']['resample_every'] == 0) & \
|
||||||
|
(epoch < cfg['train']['total_epochs']):
|
||||||
|
inputs = trainer.point_resampling(inputs)
|
||||||
|
optimizer = update_optimizer(inputs, cfg,
|
||||||
|
epoch=epoch, model=model, schedule=lr_schedules)
|
||||||
|
trainer = Trainer(cfg, optimizer, device=device)
|
||||||
|
|
||||||
|
# visualize the Open3D outputs
|
||||||
|
if cfg['train']['o3d_show']:
|
||||||
|
out_video_dir = os.path.join(cfg['train']['out_dir'],
|
||||||
|
'vis/o3d/video.mp4')
|
||||||
|
if os.path.isfile(out_video_dir):
|
||||||
|
os.system('rm {}'.format(out_video_dir))
|
||||||
|
os.system('ffmpeg -framerate 30 \
|
||||||
|
-start_number 0 \
|
||||||
|
-i {}/vis/o3d/%04d.jpg \
|
||||||
|
-pix_fmt yuv420p \
|
||||||
|
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
|
||||||
|
out_video_dir = os.path.join(cfg['train']['out_dir'],
|
||||||
|
'vis/o3d/video_pcd.mp4')
|
||||||
|
if os.path.isfile(out_video_dir):
|
||||||
|
os.system('rm {}'.format(out_video_dir))
|
||||||
|
os.system('ffmpeg -framerate 30 \
|
||||||
|
-start_number 0 \
|
||||||
|
-i {}/vis/o3d/%04d_pcd.jpg \
|
||||||
|
-pix_fmt yuv420p \
|
||||||
|
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
|
||||||
|
print('Video saved.')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
69
optim_hierarchy.py
Normal file
69
optim_hierarchy.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import sys, os
|
||||||
|
import argparse
|
||||||
|
from src.utils import load_config
|
||||||
|
import subprocess
|
||||||
|
os.environ['MKL_THREADING_LAYER'] = 'GNU'
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||||
|
parser.add_argument('config', type=str, help='Path to config file.')
|
||||||
|
parser.add_argument('--start_res', type=int, default=-1, help='Resolution to start with.')
|
||||||
|
parser.add_argument('--object_id', type=int, default=-1, help='Object index.')
|
||||||
|
|
||||||
|
args, unknown = parser.parse_known_args()
|
||||||
|
cfg = load_config(args.config, 'configs/default.yaml')
|
||||||
|
|
||||||
|
resolutions=[32, 64, 128, 256]
|
||||||
|
iterations=[1000, 1000, 1000, 200]
|
||||||
|
lrs=[2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)] # reduce lr
|
||||||
|
for idx,(res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
|
||||||
|
|
||||||
|
if res<args.start_res:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if res>cfg['model']['grid_res']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
psr_sigma= 2 if res<=128 else 3
|
||||||
|
|
||||||
|
if res > 128:
|
||||||
|
psr_sigma = 5 if 'thingi_noisy' in args.config else 3
|
||||||
|
|
||||||
|
if args.object_id != -1:
|
||||||
|
out_dir = os.path.join(cfg['train']['out_dir'], 'object_%02d'%args.object_id, 'res_%d'%res)
|
||||||
|
else:
|
||||||
|
out_dir = os.path.join(cfg['train']['out_dir'], 'res_%d'%res)
|
||||||
|
|
||||||
|
# sample from mesh when resampling is enabled, otherwise reuse the pointcloud
|
||||||
|
init_shape='mesh' if cfg['train']['resample_every']>0 else 'pointcloud'
|
||||||
|
|
||||||
|
|
||||||
|
if args.object_id != -1:
|
||||||
|
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
|
||||||
|
'object_%02d'%args.object_id, 'res_%d' % (resolutions[idx-1]),
|
||||||
|
'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
|
||||||
|
else:
|
||||||
|
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
|
||||||
|
'res_%d' % (resolutions[idx-1]),
|
||||||
|
'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
|
||||||
|
|
||||||
|
|
||||||
|
cmd = 'export MKL_SERVICE_FORCE_INTEL=1 && '
|
||||||
|
cmd += "python optim.py %s --model:grid_res %d --model:psr_sigma %d \
|
||||||
|
--train:input_mesh %s --train:total_epochs %d \
|
||||||
|
--train:out_dir %s --train:lr_pcl %f \
|
||||||
|
--data:object_id %d" % (
|
||||||
|
args.config,
|
||||||
|
res,
|
||||||
|
psr_sigma,
|
||||||
|
input_mesh,
|
||||||
|
iteration,
|
||||||
|
out_dir,
|
||||||
|
lr,
|
||||||
|
args.object_id)
|
||||||
|
print(cmd)
|
||||||
|
os.system(cmd)
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
main()
|
8
scripts/download_demo_data.sh
Normal file
8
scripts/download_demo_data.sh
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
#!/bin/bash
|
||||||
|
mkdir -p data
|
||||||
|
cd data
|
||||||
|
echo "Start downloading ..."
|
||||||
|
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/demo.zip
|
||||||
|
unzip demo.zip
|
||||||
|
rm demo.zip
|
||||||
|
echo "Done!"
|
8
scripts/download_optim_data.sh
Normal file
8
scripts/download_optim_data.sh
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
#!/bin/bash
|
||||||
|
mkdir -p data
|
||||||
|
cd data
|
||||||
|
echo "Start downloading data for optimization-based setting (~200 MB)"
|
||||||
|
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/optim_data.zip
|
||||||
|
unzip optim_data.zip
|
||||||
|
rm optim_data.zip
|
||||||
|
echo "Done!"
|
8
scripts/download_shapenet.sh
Normal file
8
scripts/download_shapenet.sh
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
#!/bin/bash
|
||||||
|
mkdir -p data
|
||||||
|
cd data
|
||||||
|
echo "Start downloading preprocessed ShapeNet data (~220G)"
|
||||||
|
wget https://s3.eu-central-1.amazonaws.com/avg-projects/shape_as_points/data/shapenet_psr.zip
|
||||||
|
unzip shapenet_psr.zip
|
||||||
|
rm shapenet_psr.zip
|
||||||
|
echo "Done!"
|
101
scripts/process_shapenet.py
Normal file
101
scripts/process_shapenet.py
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import multiprocessing
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from src.dpsr import DPSR
|
||||||
|
|
||||||
|
data_path = 'data/ShapeNet' # path for ShapeNet from ONet
|
||||||
|
base = 'data' # output base directory
|
||||||
|
dataset_name = 'shapenet_psr'
|
||||||
|
multiprocess = True
|
||||||
|
njobs = 8
|
||||||
|
save_pointcloud = True
|
||||||
|
save_psr_field = True
|
||||||
|
resolution = 128
|
||||||
|
zero_level = 0.0
|
||||||
|
num_points = 100000
|
||||||
|
padding = 1.2
|
||||||
|
|
||||||
|
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
|
||||||
|
|
||||||
|
def process_one(obj):
|
||||||
|
|
||||||
|
obj_name = obj.split('/')[-1]
|
||||||
|
c = obj.split('/')[-2]
|
||||||
|
|
||||||
|
# create new for the current object
|
||||||
|
out_path_cur = os.path.join(base, dataset_name, c)
|
||||||
|
out_path_cur_obj = os.path.join(out_path_cur, obj_name)
|
||||||
|
os.makedirs(out_path_cur_obj, exist_ok=True)
|
||||||
|
|
||||||
|
gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz')
|
||||||
|
data = np.load(gt_path)
|
||||||
|
points = data['points']
|
||||||
|
normals = data['normals']
|
||||||
|
|
||||||
|
# normalize the point to [0, 1)
|
||||||
|
points = points / padding + 0.5
|
||||||
|
# to scale back during inference, we should:
|
||||||
|
#! p = (p - 0.5) * padding
|
||||||
|
|
||||||
|
if save_pointcloud:
|
||||||
|
outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
|
||||||
|
# np.savez(outdir, points=points, normals=normals)
|
||||||
|
np.savez(outdir, points=data['points'], normals=data['normals'])
|
||||||
|
# return
|
||||||
|
|
||||||
|
if save_psr_field:
|
||||||
|
psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None],
|
||||||
|
torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16)
|
||||||
|
|
||||||
|
outdir = os.path.join(out_path_cur_obj, 'psr.npz')
|
||||||
|
np.savez(outdir, psr=psr_gt)
|
||||||
|
|
||||||
|
|
||||||
|
def main(c):
|
||||||
|
|
||||||
|
print('---------------------------------------')
|
||||||
|
print('Processing {} {}'.format(c, split))
|
||||||
|
print('---------------------------------------')
|
||||||
|
|
||||||
|
for split in ['train', 'val', 'test']:
|
||||||
|
fname = os.path.join(data_path, c, split+'.lst')
|
||||||
|
with open(fname, 'r') as f:
|
||||||
|
obj_list = f.read().splitlines()
|
||||||
|
|
||||||
|
obj_list = [c+'/'+s for s in obj_list]
|
||||||
|
|
||||||
|
if multiprocess:
|
||||||
|
# multiprocessing.set_start_method('spawn', force=True)
|
||||||
|
pool = multiprocessing.Pool(njobs)
|
||||||
|
try:
|
||||||
|
for _ in tqdm(pool.imap_unordered(process_one, obj_list), total=len(obj_list)):
|
||||||
|
pass
|
||||||
|
# pool.map_async(process_one, obj_list).get()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# Allow ^C to interrupt from any thread.
|
||||||
|
exit()
|
||||||
|
pool.close()
|
||||||
|
else:
|
||||||
|
for obj in tqdm(obj_list):
|
||||||
|
process_one(obj)
|
||||||
|
|
||||||
|
print('Done Processing {} {}!'.format(c, split))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
classes = ['02691156', '02828884', '02933112',
|
||||||
|
'02958343', '03211117', '03001627',
|
||||||
|
'03636649', '03691459', '04090263',
|
||||||
|
'04256520', '04379243', '04401088', '04530566']
|
||||||
|
|
||||||
|
|
||||||
|
t_start = time.time()
|
||||||
|
for c in classes:
|
||||||
|
main(c)
|
||||||
|
|
||||||
|
t_end = time.time()
|
||||||
|
print('Total processing time: ', t_end - t_start)
|
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
146
src/config.py
Normal file
146
src/config.py
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
import yaml
|
||||||
|
from torchvision import transforms
|
||||||
|
from src import data, generation
|
||||||
|
from src.dpsr import DPSR
|
||||||
|
from ipdb import set_trace as st
|
||||||
|
|
||||||
|
|
||||||
|
# Generator for final mesh extraction
|
||||||
|
def get_generator(model, cfg, device, **kwargs):
|
||||||
|
''' Returns the generator object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Occupancy Network model
|
||||||
|
cfg (dict): imported yaml config
|
||||||
|
device (device): pytorch device
|
||||||
|
'''
|
||||||
|
|
||||||
|
if cfg['generation']['psr_resolution'] == 0:
|
||||||
|
psr_res = cfg['model']['grid_res']
|
||||||
|
psr_sigma = cfg['model']['psr_sigma']
|
||||||
|
else:
|
||||||
|
psr_res = cfg['generation']['psr_resolution']
|
||||||
|
psr_sigma = cfg['generation']['psr_sigma']
|
||||||
|
|
||||||
|
dpsr = DPSR(res=(psr_res, psr_res, psr_res),
|
||||||
|
sig= psr_sigma).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
generator = generation.Generator3D(
|
||||||
|
model,
|
||||||
|
device=device,
|
||||||
|
threshold=cfg['data']['zero_level'],
|
||||||
|
sample=cfg['generation']['use_sampling'],
|
||||||
|
input_type = cfg['data']['input_type'],
|
||||||
|
padding=cfg['data']['padding'],
|
||||||
|
dpsr=dpsr,
|
||||||
|
psr_tanh=cfg['model']['psr_tanh']
|
||||||
|
)
|
||||||
|
return generator
|
||||||
|
|
||||||
|
# Datasets
|
||||||
|
def get_dataset(mode, cfg, return_idx=False):
|
||||||
|
''' Returns the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): the model which is used
|
||||||
|
cfg (dict): config dictionary
|
||||||
|
return_idx (bool): whether to include an ID field
|
||||||
|
'''
|
||||||
|
dataset_type = cfg['data']['dataset']
|
||||||
|
dataset_folder = cfg['data']['path']
|
||||||
|
categories = cfg['data']['class']
|
||||||
|
|
||||||
|
# Get split
|
||||||
|
splits = {
|
||||||
|
'train': cfg['data']['train_split'],
|
||||||
|
'val': cfg['data']['val_split'],
|
||||||
|
'test': cfg['data']['test_split'],
|
||||||
|
'vis': cfg['data']['val_split'],
|
||||||
|
}
|
||||||
|
|
||||||
|
split = splits[mode]
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
if dataset_type == 'Shapes3D':
|
||||||
|
fields = get_data_fields(mode, cfg)
|
||||||
|
# Input fields
|
||||||
|
inputs_field = get_inputs_field(mode, cfg)
|
||||||
|
if inputs_field is not None:
|
||||||
|
fields['inputs'] = inputs_field
|
||||||
|
|
||||||
|
if return_idx:
|
||||||
|
fields['idx'] = data.IndexField()
|
||||||
|
|
||||||
|
dataset = data.Shapes3dDataset(
|
||||||
|
dataset_folder, fields,
|
||||||
|
split=split,
|
||||||
|
categories=categories,
|
||||||
|
cfg = cfg
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_inputs_field(mode, cfg):
|
||||||
|
''' Returns the inputs fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (str): the mode which is used
|
||||||
|
cfg (dict): config dictionary
|
||||||
|
'''
|
||||||
|
input_type = cfg['data']['input_type']
|
||||||
|
|
||||||
|
if input_type is None:
|
||||||
|
inputs_field = None
|
||||||
|
elif input_type == 'pointcloud':
|
||||||
|
noise_level = cfg['data']['pointcloud_noise']
|
||||||
|
if cfg['data']['pointcloud_outlier_ratio']>0:
|
||||||
|
transform = transforms.Compose([
|
||||||
|
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
|
||||||
|
data.PointcloudNoise(noise_level),
|
||||||
|
data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio'])
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
transform = transforms.Compose([
|
||||||
|
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
|
||||||
|
data.PointcloudNoise(noise_level)
|
||||||
|
])
|
||||||
|
|
||||||
|
data_type = cfg['data']['data_type']
|
||||||
|
inputs_field = data.PointCloudField(
|
||||||
|
cfg['data']['pointcloud_file'], data_type, transform,
|
||||||
|
multi_files= cfg['data']['multi_files']
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Invalid input type (%s)' % input_type)
|
||||||
|
return inputs_field
|
||||||
|
|
||||||
|
def get_data_fields(mode, cfg):
|
||||||
|
''' Returns the data fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (str): the mode which is used
|
||||||
|
cfg (dict): imported yaml config
|
||||||
|
'''
|
||||||
|
data_type = cfg['data']['data_type']
|
||||||
|
fields = {}
|
||||||
|
|
||||||
|
if (mode in ('val', 'test')):
|
||||||
|
transform = data.SubsamplePointcloud(100000)
|
||||||
|
else:
|
||||||
|
transform = data.SubsamplePointcloud(cfg['data']['num_gt_points'])
|
||||||
|
|
||||||
|
data_name = cfg['data']['pointcloud_file']
|
||||||
|
fields['gt_points'] = data.PointCloudField(data_name,
|
||||||
|
transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files'])
|
||||||
|
if data_type == 'psr_full':
|
||||||
|
if mode != 'test':
|
||||||
|
fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files'])
|
||||||
|
else:
|
||||||
|
raise ValueError('Invalid data type (%s)' % data_type)
|
||||||
|
|
||||||
|
return fields
|
25
src/data/__init__.py
Normal file
25
src/data/__init__.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
|
||||||
|
from src.data.core import (
|
||||||
|
Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together
|
||||||
|
)
|
||||||
|
from src.data.fields import (
|
||||||
|
IndexField, PointCloudField, FullPSRField
|
||||||
|
)
|
||||||
|
from src.data.transforms import (
|
||||||
|
PointcloudNoise, SubsamplePointcloud,
|
||||||
|
PointcloudOutliers,
|
||||||
|
)
|
||||||
|
__all__ = [
|
||||||
|
# Core
|
||||||
|
Shapes3dDataset,
|
||||||
|
collate_remove_none,
|
||||||
|
worker_init_fn,
|
||||||
|
# Fields
|
||||||
|
IndexField,
|
||||||
|
PointCloudField,
|
||||||
|
FullPSRField,
|
||||||
|
# Transforms
|
||||||
|
PointcloudNoise,
|
||||||
|
SubsamplePointcloud,
|
||||||
|
PointcloudOutliers,
|
||||||
|
]
|
237
src/data/core.py
Normal file
237
src/data/core.py
Normal file
|
@ -0,0 +1,237 @@
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from torch.utils import data
|
||||||
|
from pdb import set_trace as st
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
class Field(object):
|
||||||
|
''' Data fields class.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def load(self, data_path, idx, category):
|
||||||
|
''' Loads a data point.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_path (str): path to data file
|
||||||
|
idx (int): index of data point
|
||||||
|
category (int): index of category
|
||||||
|
'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def check_complete(self, files):
|
||||||
|
''' Checks if set is complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: files
|
||||||
|
'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class Shapes3dDataset(data.Dataset):
|
||||||
|
''' 3D Shapes dataset class.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, dataset_folder, fields, split=None,
|
||||||
|
categories=None, no_except=True, transform=None, cfg=None):
|
||||||
|
''' Initialization of the the 3D shape dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_folder (str): dataset folder
|
||||||
|
fields (dict): dictionary of fields
|
||||||
|
split (str): which split is used
|
||||||
|
categories (list): list of categories to use
|
||||||
|
no_except (bool): no exception
|
||||||
|
transform (callable): transformation applied to data points
|
||||||
|
cfg (yaml): config file
|
||||||
|
'''
|
||||||
|
# Attributes
|
||||||
|
self.dataset_folder = dataset_folder
|
||||||
|
self.fields = fields
|
||||||
|
self.no_except = no_except
|
||||||
|
self.transform = transform
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
# If categories is None, use all subfolders
|
||||||
|
if categories is None:
|
||||||
|
categories = os.listdir(dataset_folder)
|
||||||
|
categories = [c for c in categories
|
||||||
|
if os.path.isdir(os.path.join(dataset_folder, c))]
|
||||||
|
|
||||||
|
# Read metadata file
|
||||||
|
metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
|
||||||
|
|
||||||
|
if os.path.exists(metadata_file):
|
||||||
|
with open(metadata_file, 'r') as f:
|
||||||
|
self.metadata = yaml.load(f, Loader=yaml.Loader)
|
||||||
|
else:
|
||||||
|
self.metadata = {
|
||||||
|
c: {'id': c, 'name': 'n/a'} for c in categories
|
||||||
|
}
|
||||||
|
|
||||||
|
# Set index
|
||||||
|
for c_idx, c in enumerate(categories):
|
||||||
|
self.metadata[c]['idx'] = c_idx
|
||||||
|
|
||||||
|
# Get all models
|
||||||
|
self.models = []
|
||||||
|
for c_idx, c in enumerate(categories):
|
||||||
|
subpath = os.path.join(dataset_folder, c)
|
||||||
|
if not os.path.isdir(subpath):
|
||||||
|
logger.warning('Category %s does not exist in dataset.' % c)
|
||||||
|
|
||||||
|
if split is None:
|
||||||
|
self.models += [
|
||||||
|
{'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ]
|
||||||
|
]
|
||||||
|
|
||||||
|
else:
|
||||||
|
split_file = os.path.join(subpath, split + '.lst')
|
||||||
|
with open(split_file, 'r') as f:
|
||||||
|
models_c = f.read().split('\n')
|
||||||
|
|
||||||
|
if '' in models_c:
|
||||||
|
models_c.remove('')
|
||||||
|
|
||||||
|
self.models += [
|
||||||
|
{'category': c, 'model': m}
|
||||||
|
for m in models_c
|
||||||
|
]
|
||||||
|
|
||||||
|
# precompute
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
''' Returns the length of the dataset.
|
||||||
|
'''
|
||||||
|
return len(self.models)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
''' Returns an item of the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx (int): ID of data point
|
||||||
|
'''
|
||||||
|
|
||||||
|
category = self.models[idx]['category']
|
||||||
|
model = self.models[idx]['model']
|
||||||
|
c_idx = self.metadata[category]['idx']
|
||||||
|
|
||||||
|
model_path = os.path.join(self.dataset_folder, category, model)
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
info = c_idx
|
||||||
|
|
||||||
|
if self.cfg['data']['multi_files'] is not None:
|
||||||
|
idx = np.random.randint(self.cfg['data']['multi_files'])
|
||||||
|
if self.split != 'train':
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
for field_name, field in self.fields.items():
|
||||||
|
try:
|
||||||
|
field_data = field.load(model_path, idx, info)
|
||||||
|
except Exception:
|
||||||
|
if self.no_except:
|
||||||
|
logger.warn(
|
||||||
|
'Error occured when loading field %s of model %s'
|
||||||
|
% (field_name, model)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
if isinstance(field_data, dict):
|
||||||
|
for k, v in field_data.items():
|
||||||
|
if k is None:
|
||||||
|
data[field_name] = v
|
||||||
|
else:
|
||||||
|
data['%s.%s' % (field_name, k)] = v
|
||||||
|
else:
|
||||||
|
data[field_name] = field_data
|
||||||
|
|
||||||
|
if self.transform is not None:
|
||||||
|
data = self.transform(data)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_dict(self, idx):
|
||||||
|
return self.models[idx]
|
||||||
|
|
||||||
|
def test_model_complete(self, category, model):
|
||||||
|
''' Tests if model is complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): modelname
|
||||||
|
'''
|
||||||
|
model_path = os.path.join(self.dataset_folder, category, model)
|
||||||
|
files = os.listdir(model_path)
|
||||||
|
for field_name, field in self.fields.items():
|
||||||
|
if not field.check_complete(files):
|
||||||
|
logger.warn('Field "%s" is incomplete: %s'
|
||||||
|
% (field_name, model_path))
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def collate_remove_none(batch):
|
||||||
|
''' Collater that puts each data field into a tensor with outer dimension
|
||||||
|
batch size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: batch
|
||||||
|
'''
|
||||||
|
|
||||||
|
batch = list(filter(lambda x: x is not None, batch))
|
||||||
|
return data.dataloader.default_collate(batch)
|
||||||
|
|
||||||
|
def collate_stack_together(batch):
|
||||||
|
''' Collater that puts each data field into a tensor with outer dimension
|
||||||
|
batch size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: batch
|
||||||
|
'''
|
||||||
|
|
||||||
|
batch = list(filter(lambda x: x is not None, batch))
|
||||||
|
keys = batch[0].keys()
|
||||||
|
concat = {}
|
||||||
|
if len(batch)>1:
|
||||||
|
for key in keys:
|
||||||
|
key_val = [item[key] for item in batch]
|
||||||
|
concat[key] = np.concatenate(key_val, axis=0)
|
||||||
|
if key == 'inputs':
|
||||||
|
n_pts = [item[key].shape[0] for item in batch]
|
||||||
|
|
||||||
|
concat['batch_ind'] = np.concatenate(
|
||||||
|
[i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
|
||||||
|
|
||||||
|
return data.dataloader.default_collate([concat])
|
||||||
|
else:
|
||||||
|
n_pts = batch[0]['inputs'].shape[0]
|
||||||
|
batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int)
|
||||||
|
return data.dataloader.default_collate(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def worker_init_fn(worker_id):
|
||||||
|
''' Worker init function to ensure true randomness.
|
||||||
|
'''
|
||||||
|
def set_num_threads(nt):
|
||||||
|
try:
|
||||||
|
import mkl; mkl.set_num_threads(nt)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
os.environ['IPC_ENABLE']='1'
|
||||||
|
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
|
||||||
|
os.environ[o] = str(nt)
|
||||||
|
|
||||||
|
random_data = os.urandom(4)
|
||||||
|
base_seed = int.from_bytes(random_data, byteorder="big")
|
||||||
|
np.random.seed(base_seed + worker_id)
|
118
src/data/fields.py
Normal file
118
src/data/fields.py
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import trimesh
|
||||||
|
from src.data.core import Field
|
||||||
|
from pdb import set_trace as st
|
||||||
|
|
||||||
|
|
||||||
|
class IndexField(Field):
|
||||||
|
''' Basic index field.'''
|
||||||
|
def load(self, model_path, idx, category):
|
||||||
|
''' Loads the index field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): path to model
|
||||||
|
idx (int): ID of data point
|
||||||
|
category (int): index of category
|
||||||
|
'''
|
||||||
|
return idx
|
||||||
|
|
||||||
|
def check_complete(self, files):
|
||||||
|
''' Check if field is complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: files
|
||||||
|
'''
|
||||||
|
return True
|
||||||
|
|
||||||
|
class FullPSRField(Field):
|
||||||
|
def __init__(self, transform=None, multi_files=None):
|
||||||
|
self.transform = transform
|
||||||
|
# self.unpackbits = unpackbits
|
||||||
|
self.multi_files = multi_files
|
||||||
|
|
||||||
|
def load(self, model_path, idx, category):
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# t0 = time.time()
|
||||||
|
if self.multi_files is not None:
|
||||||
|
psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx))
|
||||||
|
else:
|
||||||
|
psr_path = os.path.join(model_path, 'psr.npz')
|
||||||
|
psr_dict = np.load(psr_path)
|
||||||
|
# t1 = time.time()
|
||||||
|
psr = psr_dict['psr']
|
||||||
|
psr = psr.astype(np.float32)
|
||||||
|
# t2 = time.time()
|
||||||
|
# print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0))
|
||||||
|
data = {None: psr}
|
||||||
|
|
||||||
|
if self.transform is not None:
|
||||||
|
data = self.transform(data)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
class PointCloudField(Field):
|
||||||
|
''' Point cloud field.
|
||||||
|
|
||||||
|
It provides the field used for point cloud data. These are the points
|
||||||
|
randomly sampled on the mesh.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name (str): file name
|
||||||
|
transform (list): list of transformations applied to data points
|
||||||
|
multi_files (callable): number of files
|
||||||
|
'''
|
||||||
|
def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
|
||||||
|
self.file_name = file_name
|
||||||
|
self.data_type = data_type # to make sure the range of input is correct
|
||||||
|
self.transform = transform
|
||||||
|
self.multi_files = multi_files
|
||||||
|
self.padding = padding
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def load(self, model_path, idx, category):
|
||||||
|
''' Loads the data point.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): path to model
|
||||||
|
idx (int): ID of data point
|
||||||
|
category (int): index of category
|
||||||
|
'''
|
||||||
|
if self.multi_files is None:
|
||||||
|
file_path = os.path.join(model_path, self.file_name)
|
||||||
|
else:
|
||||||
|
# num = np.random.randint(self.multi_files)
|
||||||
|
# file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num))
|
||||||
|
file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx))
|
||||||
|
|
||||||
|
pointcloud_dict = np.load(file_path)
|
||||||
|
|
||||||
|
points = pointcloud_dict['points'].astype(np.float32)
|
||||||
|
normals = pointcloud_dict['normals'].astype(np.float32)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
None: points,
|
||||||
|
'normals': normals,
|
||||||
|
}
|
||||||
|
if self.transform is not None:
|
||||||
|
data = self.transform(data)
|
||||||
|
|
||||||
|
if self.data_type == 'psr_full':
|
||||||
|
# scale the point cloud to the range of (0, 1)
|
||||||
|
data[None] = data[None] / self.scale + 0.5
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def check_complete(self, files):
|
||||||
|
''' Check if field is complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: files
|
||||||
|
'''
|
||||||
|
complete = (self.file_name in files)
|
||||||
|
return complete
|
86
src/data/transforms.py
Normal file
86
src/data/transforms.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# Transforms
|
||||||
|
class PointcloudNoise(object):
|
||||||
|
''' Point cloud noise transformation class.
|
||||||
|
|
||||||
|
It adds noise to point cloud data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stddev (int): standard deviation
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, stddev):
|
||||||
|
self.stddev = stddev
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
''' Calls the transformation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (dictionary): data dictionary
|
||||||
|
'''
|
||||||
|
data_out = data.copy()
|
||||||
|
points = data[None]
|
||||||
|
noise = self.stddev * np.random.randn(*points.shape)
|
||||||
|
noise = noise.astype(np.float32)
|
||||||
|
data_out[None] = points + noise
|
||||||
|
return data_out
|
||||||
|
|
||||||
|
class PointcloudOutliers(object):
|
||||||
|
''' Point cloud outlier transformation class.
|
||||||
|
|
||||||
|
It adds outliers to point cloud data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ratio (int): outlier percentage to the entire point cloud
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, ratio):
|
||||||
|
self.ratio = ratio
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
''' Calls the transformation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (dictionary): data dictionary
|
||||||
|
'''
|
||||||
|
data_out = data.copy()
|
||||||
|
points = data[None]
|
||||||
|
n_points = points.shape[0]
|
||||||
|
n_outlier_points = int(n_points*self.ratio)
|
||||||
|
ind = np.random.randint(0, n_points, n_outlier_points)
|
||||||
|
|
||||||
|
outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
|
||||||
|
outliers = outliers.astype(np.float32)
|
||||||
|
points[ind] = outliers
|
||||||
|
data_out[None] = points
|
||||||
|
return data_out
|
||||||
|
|
||||||
|
class SubsamplePointcloud(object):
|
||||||
|
''' Point cloud subsampling transformation class.
|
||||||
|
|
||||||
|
It subsamples the point cloud data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
N (int): number of points to be subsampled
|
||||||
|
'''
|
||||||
|
def __init__(self, N):
|
||||||
|
self.N = N
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
''' Calls the transformation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (dict): data dictionary
|
||||||
|
'''
|
||||||
|
data_out = data.copy()
|
||||||
|
points = data[None]
|
||||||
|
|
||||||
|
indices = np.random.randint(points.shape[0], size=self.N)
|
||||||
|
data_out[None] = points[indices, :]
|
||||||
|
if 'normals' in data.keys():
|
||||||
|
normals = data['normals']
|
||||||
|
data_out['normals'] = normals[indices, :]
|
||||||
|
|
||||||
|
return data_out
|
228
src/data_loader.py
Normal file
228
src/data_loader.py
Normal file
|
@ -0,0 +1,228 @@
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from glob import glob
|
||||||
|
from torch.utils import data
|
||||||
|
from src.utils import load_rgb, load_mask, get_camera_params
|
||||||
|
from pytorch3d.renderer import PerspectiveCameras
|
||||||
|
from skimage import img_as_float32
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# Below are for the differentiable renderer
|
||||||
|
# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py
|
||||||
|
|
||||||
|
def load_rgb(path):
|
||||||
|
img = imageio.imread(path)
|
||||||
|
img = img_as_float32(img)
|
||||||
|
|
||||||
|
# pixel values between [-1,1]
|
||||||
|
# img -= 0.5
|
||||||
|
# img *= 2.
|
||||||
|
|
||||||
|
# img = img.transpose(2, 0, 1)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def load_mask(path):
|
||||||
|
alpha = imageio.imread(path, as_gray=True)
|
||||||
|
alpha = img_as_float32(alpha)
|
||||||
|
object_mask = alpha > 127.5
|
||||||
|
|
||||||
|
return object_mask
|
||||||
|
|
||||||
|
|
||||||
|
def get_camera_params(uv, pose, intrinsics):
|
||||||
|
if pose.shape[1] == 7: #In case of quaternion vector representation
|
||||||
|
cam_loc = pose[:, 4:]
|
||||||
|
R = quat_to_rot(pose[:,:4])
|
||||||
|
p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
|
||||||
|
p[:, :3, :3] = R
|
||||||
|
p[:, :3, 3] = cam_loc
|
||||||
|
else: # In case of pose matrix representation
|
||||||
|
cam_loc = pose[:, :3, 3]
|
||||||
|
p = pose
|
||||||
|
|
||||||
|
batch_size, num_samples, _ = uv.shape
|
||||||
|
|
||||||
|
depth = torch.ones((batch_size, num_samples))
|
||||||
|
x_cam = uv[:, :, 0].view(batch_size, -1)
|
||||||
|
y_cam = uv[:, :, 1].view(batch_size, -1)
|
||||||
|
z_cam = depth.view(batch_size, -1)
|
||||||
|
|
||||||
|
pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
|
||||||
|
|
||||||
|
# permute for batch matrix product
|
||||||
|
pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
|
||||||
|
|
||||||
|
world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
|
||||||
|
ray_dirs = world_coords - cam_loc[:, None, :]
|
||||||
|
ray_dirs = F.normalize(ray_dirs, dim=2)
|
||||||
|
|
||||||
|
return ray_dirs, cam_loc
|
||||||
|
|
||||||
|
def quat_to_rot(q):
|
||||||
|
batch_size, _ = q.shape
|
||||||
|
q = F.normalize(q, dim=1)
|
||||||
|
R = torch.ones((batch_size, 3,3)).cuda()
|
||||||
|
qr=q[:,0]
|
||||||
|
qi = q[:, 1]
|
||||||
|
qj = q[:, 2]
|
||||||
|
qk = q[:, 3]
|
||||||
|
R[:, 0, 0]=1-2 * (qj**2 + qk**2)
|
||||||
|
R[:, 0, 1] = 2 * (qj *qi -qk*qr)
|
||||||
|
R[:, 0, 2] = 2 * (qi * qk + qr * qj)
|
||||||
|
R[:, 1, 0] = 2 * (qj * qi + qk * qr)
|
||||||
|
R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
|
||||||
|
R[:, 1, 2] = 2*(qj*qk - qi*qr)
|
||||||
|
R[:, 2, 0] = 2 * (qk * qi-qj * qr)
|
||||||
|
R[:, 2, 1] = 2 * (qj*qk + qi*qr)
|
||||||
|
R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
|
||||||
|
return R
|
||||||
|
|
||||||
|
def lift(x, y, z, intrinsics):
|
||||||
|
# parse intrinsics
|
||||||
|
# intrinsics = intrinsics.cuda()
|
||||||
|
fx = intrinsics[:, 0, 0]
|
||||||
|
fy = intrinsics[:, 1, 1]
|
||||||
|
cx = intrinsics[:, 0, 2]
|
||||||
|
cy = intrinsics[:, 1, 2]
|
||||||
|
sk = intrinsics[:, 0, 1]
|
||||||
|
|
||||||
|
x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
|
||||||
|
y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
|
||||||
|
|
||||||
|
# homogeneous
|
||||||
|
return torch.stack((x_lift, y_lift, z, torch.ones_like(z)), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class PixelNeRFDTUDataset(data.Dataset):
|
||||||
|
"""
|
||||||
|
Processed DTU from pixelNeRF
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
data_dir='data/DTU',
|
||||||
|
scan_id=65,
|
||||||
|
img_size=None,
|
||||||
|
device=None,
|
||||||
|
fixed_scale=0,
|
||||||
|
):
|
||||||
|
data_dir = os.path.join(data_dir, "scan{}".format(scan_id))
|
||||||
|
rgb_paths = [
|
||||||
|
x for x in glob(os.path.join(data_dir, "image", "*"))
|
||||||
|
if (x.endswith(".jpg") or x.endswith(".png"))
|
||||||
|
]
|
||||||
|
rgb_paths = sorted(rgb_paths)
|
||||||
|
mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))
|
||||||
|
if len(mask_paths) == 0:
|
||||||
|
mask_paths = [None] * len(rgb_paths)
|
||||||
|
sel_indices = np.arange(len(rgb_paths))
|
||||||
|
|
||||||
|
cam_path = os.path.join(data_dir, "cameras.npz")
|
||||||
|
all_cam = np.load(cam_path)
|
||||||
|
all_imgs = []
|
||||||
|
all_poses = []
|
||||||
|
all_masks = []
|
||||||
|
all_rays = []
|
||||||
|
all_light_pose = []
|
||||||
|
all_K = []
|
||||||
|
all_R = []
|
||||||
|
all_T = []
|
||||||
|
|
||||||
|
for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)):
|
||||||
|
|
||||||
|
i = sel_indices[idx]
|
||||||
|
rgb = load_rgb(rgb_path)
|
||||||
|
mask = load_mask(mask_path)
|
||||||
|
rgb[~mask] = 0.
|
||||||
|
rgb = torch.from_numpy(rgb).float().to(device)
|
||||||
|
mask = torch.from_numpy(mask).float().to(device)
|
||||||
|
x_scale = y_scale = 1.0
|
||||||
|
xy_delta = 0.0
|
||||||
|
|
||||||
|
P = all_cam["world_mat_" + str(i)]
|
||||||
|
P = P[:3]
|
||||||
|
|
||||||
|
# scale the original shape to really [-0.9, 0.9]
|
||||||
|
if fixed_scale!=0.:
|
||||||
|
scale_mat_new = np.eye(4, 4)
|
||||||
|
scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
|
||||||
|
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new
|
||||||
|
else:
|
||||||
|
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)]
|
||||||
|
|
||||||
|
P = P[:3, :4]
|
||||||
|
K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
|
||||||
|
K = K / K[2, 2]
|
||||||
|
|
||||||
|
fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
|
||||||
|
|
||||||
|
########!!!!!
|
||||||
|
RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0)
|
||||||
|
tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0)
|
||||||
|
focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
|
||||||
|
pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
|
||||||
|
im_size = (rgb.shape[1], rgb.shape[0])
|
||||||
|
|
||||||
|
# check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC
|
||||||
|
s = min(im_size)
|
||||||
|
focal[:, 0] = focal[:, 0] * 2 / (s-1)
|
||||||
|
focal[:, 1] = focal[:, 1] * 2 /(s-1)
|
||||||
|
pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1)
|
||||||
|
pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1)
|
||||||
|
|
||||||
|
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc,
|
||||||
|
device=device, R=RR, T=tt)
|
||||||
|
|
||||||
|
# calculate camera rays
|
||||||
|
uv = uv_creation(im_size)[None].float()
|
||||||
|
pose = np.eye(4, dtype=np.float32)
|
||||||
|
pose[:3, :3] = R.transpose()
|
||||||
|
pose[:3,3] = (t[:3] / t[3])[:,0]
|
||||||
|
pose = torch.from_numpy(pose)[None].float()
|
||||||
|
intrinsics = np.eye(4)
|
||||||
|
intrinsics[:3, :3] = K
|
||||||
|
intrinsics[0, 1] = 0. #! remove skew for now
|
||||||
|
intrinsics = torch.from_numpy(intrinsics)[None].float()
|
||||||
|
|
||||||
|
|
||||||
|
rays, _ = get_camera_params(uv, pose, intrinsics)
|
||||||
|
rays = -rays.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
all_poses.append(camera)
|
||||||
|
all_imgs.append(rgb)
|
||||||
|
all_masks.append(mask)
|
||||||
|
all_rays.append(rays)
|
||||||
|
all_light_pose.append(pose)
|
||||||
|
# only for neural renderer
|
||||||
|
all_K.append(torch.tensor(K).to(device))
|
||||||
|
all_R.append(torch.tensor(R).to(device))
|
||||||
|
all_T.append(torch.tensor(t[:3]/t[3]).to(device))
|
||||||
|
|
||||||
|
all_imgs = torch.stack(all_imgs)
|
||||||
|
all_masks = torch.stack(all_masks)
|
||||||
|
all_rays = torch.stack(all_rays)
|
||||||
|
all_light_pose = torch.stack(all_light_pose).squeeze()
|
||||||
|
# only for neural renderer
|
||||||
|
all_K = torch.stack(all_K).float()
|
||||||
|
all_R = torch.stack(all_R).float()
|
||||||
|
all_T = torch.stack(all_T).permute(0, 2, 1).float()
|
||||||
|
|
||||||
|
uv = uv_creation((all_imgs.size(2), all_imgs.size(1)))
|
||||||
|
self.data = {'rgbs': all_imgs,
|
||||||
|
'masks': all_masks,
|
||||||
|
'poses': all_poses,
|
||||||
|
'rays': all_rays,
|
||||||
|
'uv': uv,
|
||||||
|
'light_pose': all_light_pose, # for rendering lights
|
||||||
|
'K': all_K,
|
||||||
|
'R': all_R,
|
||||||
|
'T': all_T,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.data
|
75
src/dpsr.py
Normal file
75
src/dpsr.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
|
||||||
|
import numpy as np
|
||||||
|
import torch.fft
|
||||||
|
|
||||||
|
class DPSR(nn.Module):
|
||||||
|
def __init__(self, res, sig=10, scale=True, shift=True):
|
||||||
|
"""
|
||||||
|
:param res: tuple of output field resolution. eg., (128,128)
|
||||||
|
:param sig: degree of gaussian smoothing
|
||||||
|
"""
|
||||||
|
super(DPSR, self).__init__()
|
||||||
|
self.res = res
|
||||||
|
self.sig = sig
|
||||||
|
self.dim = len(res)
|
||||||
|
self.denom = np.prod(res)
|
||||||
|
G = spec_gaussian_filter(res=res, sig=sig).float()
|
||||||
|
G = G
|
||||||
|
# self.G.requires_grad = False # True, if we also make sig a learnable parameter
|
||||||
|
self.omega = fftfreqs(res, dtype=torch.float32)
|
||||||
|
self.scale = scale
|
||||||
|
self.shift = shift
|
||||||
|
self.register_buffer("G", G)
|
||||||
|
|
||||||
|
def forward(self, V, N):
|
||||||
|
"""
|
||||||
|
:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
|
||||||
|
:param N: (batch, nv, 2 or 3) tensor for point normals
|
||||||
|
:return phi: (batch, res, res, ...) tensor of output indicator function field
|
||||||
|
"""
|
||||||
|
assert(V.shape == N.shape) # [b, nv, ndims]
|
||||||
|
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
|
||||||
|
|
||||||
|
#!!! OLD
|
||||||
|
# ras_s = torch.rfft(ras_p, signal_ndim=self.dim) # [b, n_dim, dim0, dim1, dim2/2+1, 2]
|
||||||
|
# ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+2))+[1, self.dim+2]))
|
||||||
|
# N_ = (ras_s * self.G) # [b, n_dim, dim0, dim1, dim2/2+1, 2]
|
||||||
|
|
||||||
|
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
|
||||||
|
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
|
||||||
|
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
|
||||||
|
|
||||||
|
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
|
||||||
|
omega *= 2 * np.pi # normalize frequencies
|
||||||
|
omega = omega.to(V.device)
|
||||||
|
|
||||||
|
# DivN = torch.sum(-img(N_) * omega, dim=-2) #!!! OLD [b, dim0, dim1, dim2/2+1, 2]
|
||||||
|
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
|
||||||
|
|
||||||
|
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
|
||||||
|
Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
|
||||||
|
Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b]
|
||||||
|
Phi[tuple([0] * self.dim)] = 0
|
||||||
|
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
|
||||||
|
|
||||||
|
# phi = torch.irfft(Phi, signal_ndim=self.dim, signal_sizes=self.res)#!!! OLD [b, dim0, dim1, dim2]
|
||||||
|
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
|
||||||
|
|
||||||
|
if self.shift or self.scale:
|
||||||
|
# ensure values at points are zero
|
||||||
|
fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
|
||||||
|
if self.shift: # offset points to have mean of 0
|
||||||
|
offset = torch.mean(fv, dim=-1) # [b,]
|
||||||
|
phi -= offset.view(*tuple([-1] + [1] * self.dim))
|
||||||
|
|
||||||
|
phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
|
||||||
|
fv0 = phi[tuple([0] * self.dim)] # [b,]
|
||||||
|
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
|
||||||
|
|
||||||
|
if self.scale:
|
||||||
|
# phi = phi / fv0.view(*tuple([-1] + [1] * self.dim)) * 0.5
|
||||||
|
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
|
||||||
|
return phi
|
168
src/eval.py
Normal file
168
src/eval.py
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import trimesh
|
||||||
|
from pykdtree.kdtree import KDTree
|
||||||
|
|
||||||
|
EMPTY_PCL_DICT = {
|
||||||
|
'completeness': np.sqrt(3),
|
||||||
|
'accuracy': np.sqrt(3),
|
||||||
|
'completeness2': 3,
|
||||||
|
'accuracy2': 3,
|
||||||
|
'chamfer': 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
EMPTY_PCL_DICT_NORMALS = {
|
||||||
|
'normals completeness': -1.,
|
||||||
|
'normals accuracy': -1.,
|
||||||
|
'normals': -1.,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MeshEvaluator(object):
|
||||||
|
''' Mesh evaluation class.
|
||||||
|
It handles the mesh evaluation process.
|
||||||
|
Args:
|
||||||
|
n_points (int): number of points to be used for evaluation
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, n_points=100000):
|
||||||
|
self.n_points = n_points
|
||||||
|
|
||||||
|
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)):
|
||||||
|
''' Evaluates a mesh.
|
||||||
|
Args:
|
||||||
|
mesh (trimesh): mesh which should be evaluated
|
||||||
|
pointcloud_tgt (numpy array): target point cloud
|
||||||
|
normals_tgt (numpy array): target normals
|
||||||
|
thresholds (numpy arry): for F-Score
|
||||||
|
'''
|
||||||
|
if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
|
||||||
|
pointcloud, idx = mesh.sample(self.n_points, return_index=True)
|
||||||
|
|
||||||
|
pointcloud = pointcloud.astype(np.float32)
|
||||||
|
normals = mesh.face_normals[idx]
|
||||||
|
else:
|
||||||
|
pointcloud = np.empty((0, 3))
|
||||||
|
normals = np.empty((0, 3))
|
||||||
|
|
||||||
|
out_dict = self.eval_pointcloud(
|
||||||
|
pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
|
||||||
|
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
def eval_pointcloud(self, pointcloud, pointcloud_tgt,
|
||||||
|
normals=None, normals_tgt=None,
|
||||||
|
thresholds=np.linspace(1./1000, 1, 1000)):
|
||||||
|
''' Evaluates a point cloud.
|
||||||
|
Args:
|
||||||
|
pointcloud (numpy array): predicted point cloud
|
||||||
|
pointcloud_tgt (numpy array): target point cloud
|
||||||
|
normals (numpy array): predicted normals
|
||||||
|
normals_tgt (numpy array): target normals
|
||||||
|
thresholds (numpy array): threshold values for the F-score calculation
|
||||||
|
'''
|
||||||
|
# Return maximum losses if pointcloud is empty
|
||||||
|
if pointcloud.shape[0] == 0:
|
||||||
|
logger.warn('Empty pointcloud / mesh detected!')
|
||||||
|
out_dict = EMPTY_PCL_DICT.copy()
|
||||||
|
if normals is not None and normals_tgt is not None:
|
||||||
|
out_dict.update(EMPTY_PCL_DICT_NORMALS)
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
pointcloud = np.asarray(pointcloud)
|
||||||
|
pointcloud_tgt = np.asarray(pointcloud_tgt)
|
||||||
|
|
||||||
|
|
||||||
|
# Completeness: how far are the points of the target point cloud
|
||||||
|
# from thre predicted point cloud
|
||||||
|
completeness, completeness_normals = distance_p2p(
|
||||||
|
pointcloud_tgt, normals_tgt, pointcloud, normals
|
||||||
|
)
|
||||||
|
recall = get_threshold_percentage(completeness, thresholds)
|
||||||
|
completeness2 = completeness**2
|
||||||
|
|
||||||
|
completeness = completeness.mean()
|
||||||
|
completeness2 = completeness2.mean()
|
||||||
|
completeness_normals = completeness_normals.mean()
|
||||||
|
|
||||||
|
# Accuracy: how far are th points of the predicted pointcloud
|
||||||
|
# from the target pointcloud
|
||||||
|
accuracy, accuracy_normals = distance_p2p(
|
||||||
|
pointcloud, normals, pointcloud_tgt, normals_tgt
|
||||||
|
)
|
||||||
|
precision = get_threshold_percentage(accuracy, thresholds)
|
||||||
|
accuracy2 = accuracy**2
|
||||||
|
|
||||||
|
accuracy = accuracy.mean()
|
||||||
|
accuracy2 = accuracy2.mean()
|
||||||
|
accuracy_normals = accuracy_normals.mean()
|
||||||
|
|
||||||
|
# Chamfer distance
|
||||||
|
chamferL2 = 0.5 * (completeness2 + accuracy2)
|
||||||
|
normals_correctness = (
|
||||||
|
0.5 * completeness_normals + 0.5 * accuracy_normals
|
||||||
|
)
|
||||||
|
chamferL1 = 0.5 * (completeness + accuracy)
|
||||||
|
|
||||||
|
# F-Score
|
||||||
|
F = [
|
||||||
|
2 * precision[i] * recall[i] / (precision[i] + recall[i])
|
||||||
|
for i in range(len(precision))
|
||||||
|
]
|
||||||
|
|
||||||
|
out_dict = {
|
||||||
|
'completeness': completeness,
|
||||||
|
'accuracy': accuracy,
|
||||||
|
'normals completeness': completeness_normals,
|
||||||
|
'normals accuracy': accuracy_normals,
|
||||||
|
'normals': normals_correctness,
|
||||||
|
'completeness2': completeness2,
|
||||||
|
'accuracy2': accuracy2,
|
||||||
|
'chamfer-L2': chamferL2,
|
||||||
|
'chamfer-L1': chamferL1,
|
||||||
|
'f-score': F[9], # threshold = 1.0%
|
||||||
|
'f-score-15': F[14], # threshold = 1.5%
|
||||||
|
'f-score-20': F[19], # threshold = 2.0%
|
||||||
|
}
|
||||||
|
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
|
||||||
|
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
|
||||||
|
''' Computes minimal distances of each point in points_src to points_tgt.
|
||||||
|
Args:
|
||||||
|
points_src (numpy array): source points
|
||||||
|
normals_src (numpy array): source normals
|
||||||
|
points_tgt (numpy array): target points
|
||||||
|
normals_tgt (numpy array): target normals
|
||||||
|
'''
|
||||||
|
kdtree = KDTree(points_tgt)
|
||||||
|
dist, idx = kdtree.query(points_src)
|
||||||
|
|
||||||
|
if normals_src is not None and normals_tgt is not None:
|
||||||
|
normals_src = \
|
||||||
|
normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
|
||||||
|
normals_tgt = \
|
||||||
|
normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
|
||||||
|
|
||||||
|
normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
|
||||||
|
# Handle normals that point into wrong direction gracefully
|
||||||
|
# (mostly due to mehtod not caring about this in generation)
|
||||||
|
normals_dot_product = np.abs(normals_dot_product)
|
||||||
|
else:
|
||||||
|
normals_dot_product = np.array(
|
||||||
|
[np.nan] * points_src.shape[0], dtype=np.float32)
|
||||||
|
return dist, normals_dot_product
|
||||||
|
|
||||||
|
def get_threshold_percentage(dist, thresholds):
|
||||||
|
''' Evaluates a point cloud.
|
||||||
|
Args:
|
||||||
|
dist (numpy array): calculated distance
|
||||||
|
thresholds (numpy array): threshold values for the F-score calculation
|
||||||
|
'''
|
||||||
|
in_threshold = [
|
||||||
|
(dist <= t).mean() for t in thresholds
|
||||||
|
]
|
||||||
|
return in_threshold
|
63
src/generation.py
Normal file
63
src/generation.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import trimesh
|
||||||
|
import numpy as np
|
||||||
|
from src.utils import mc_from_psr
|
||||||
|
|
||||||
|
class Generator3D(object):
|
||||||
|
''' Generator class for Occupancy Networks.
|
||||||
|
|
||||||
|
It provides functions to generate the final mesh as well refining options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): trained Occupancy Network model
|
||||||
|
points_batch_size (int): batch size for points evaluation
|
||||||
|
threshold (float): threshold value
|
||||||
|
device (device): pytorch device
|
||||||
|
padding (float): how much padding should be used for MISE
|
||||||
|
sample (bool): whether z should be sampled
|
||||||
|
input_type (str): type of input
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, model, points_batch_size=100000,
|
||||||
|
threshold=0.5, device=None, padding=0.1,
|
||||||
|
sample=False, input_type = None, dpsr=None, psr_tanh=True):
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.points_batch_size = points_batch_size
|
||||||
|
self.threshold = threshold
|
||||||
|
self.device = device
|
||||||
|
self.input_type = input_type
|
||||||
|
self.padding = padding
|
||||||
|
self.sample = sample
|
||||||
|
self.dpsr = dpsr
|
||||||
|
self.psr_tanh = psr_tanh
|
||||||
|
|
||||||
|
def generate_mesh(self, data, return_stats=True):
|
||||||
|
''' Generates the output mesh.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (tensor): data tensor
|
||||||
|
return_stats (bool): whether stats should be returned
|
||||||
|
'''
|
||||||
|
self.model.eval()
|
||||||
|
device = self.device
|
||||||
|
stats_dict = {}
|
||||||
|
|
||||||
|
p = data.get('inputs', torch.empty(1, 0)).to(device)
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
points, normals = self.model(p)
|
||||||
|
t1 = time.time()
|
||||||
|
psr_grid = self.dpsr(points, normals)
|
||||||
|
t2 = time.time()
|
||||||
|
v, f, _ = mc_from_psr(psr_grid,
|
||||||
|
zero_level=self.threshold)
|
||||||
|
stats_dict['pcl'] = t1 - t0
|
||||||
|
stats_dict['dpsr'] = t2 - t1
|
||||||
|
stats_dict['mc'] = time.time() - t2
|
||||||
|
stats_dict['total'] = time.time() - t0
|
||||||
|
|
||||||
|
if return_stats:
|
||||||
|
return v, f, points, normals, stats_dict
|
||||||
|
else:
|
||||||
|
return v, f, points, normals
|
181
src/model.py
Normal file
181
src/model.py
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from src.utils import point_rasterize, grid_interp, mc_from_psr, \
|
||||||
|
calc_inters_points
|
||||||
|
from src.dpsr import DPSR
|
||||||
|
import torch.nn as nn
|
||||||
|
from src.network import encoder_dict, decoder_dict
|
||||||
|
from src.network.utils import map2local
|
||||||
|
|
||||||
|
class PSR2Mesh(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, psr_grid):
|
||||||
|
"""
|
||||||
|
In the forward pass we receive a Tensor containing the input and return
|
||||||
|
a Tensor containing the output. ctx is a context object that can be used
|
||||||
|
to stash information for backward computation. You can cache arbitrary
|
||||||
|
objects for use in the backward pass using the ctx.save_for_backward method.
|
||||||
|
"""
|
||||||
|
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
|
||||||
|
verts = verts.unsqueeze(0)
|
||||||
|
faces = faces.unsqueeze(0)
|
||||||
|
normals = normals.unsqueeze(0)
|
||||||
|
|
||||||
|
res = torch.tensor(psr_grid.detach().shape[2])
|
||||||
|
ctx.save_for_backward(verts, normals, res)
|
||||||
|
|
||||||
|
return verts, faces, normals
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
|
||||||
|
"""
|
||||||
|
In the backward pass we receive a Tensor containing the gradient of the loss
|
||||||
|
with respect to the output, and we need to compute the gradient of the loss
|
||||||
|
with respect to the input.
|
||||||
|
"""
|
||||||
|
vert_pts, normals, res = ctx.saved_tensors
|
||||||
|
res = (res.item(), res.item(), res.item())
|
||||||
|
# matrix multiplication between dL/dV and dV/dPSR
|
||||||
|
# dV/dPSR = - normals
|
||||||
|
grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
|
||||||
|
grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
|
||||||
|
|
||||||
|
return grad_grid
|
||||||
|
|
||||||
|
class PSR2SurfacePoints(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
|
||||||
|
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
|
||||||
|
verts = verts * 2. - 1. # within the range of [-1, 1]
|
||||||
|
|
||||||
|
|
||||||
|
p_all, n_all, mask_all = [], [], []
|
||||||
|
|
||||||
|
for i in range(len(poses)):
|
||||||
|
pose = poses[i]
|
||||||
|
if mask_sample is not None:
|
||||||
|
p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i])
|
||||||
|
else:
|
||||||
|
p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size)
|
||||||
|
|
||||||
|
n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze()
|
||||||
|
p_all.append(p_inters)
|
||||||
|
n_all.append(n_inters)
|
||||||
|
mask_all.append(mask)
|
||||||
|
p_inters_all = torch.cat(p_all, dim=0)
|
||||||
|
n_inters_all = torch.cat(n_all, dim=0)
|
||||||
|
mask_visible = torch.stack(mask_all, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
res = torch.tensor(psr_grid.detach().shape[2])
|
||||||
|
ctx.save_for_backward(p_inters_all, n_inters_all, res)
|
||||||
|
|
||||||
|
return p_inters_all, mask_visible
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, dL_dp, dL_dmask):
|
||||||
|
pts, pts_n, res = ctx.saved_tensors
|
||||||
|
res = (res.item(), res.item(), res.item())
|
||||||
|
|
||||||
|
# grad from the p_inters via MLP renderer
|
||||||
|
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
|
||||||
|
grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
|
||||||
|
|
||||||
|
return grad_grid_pts, None, None, None, None, None
|
||||||
|
|
||||||
|
class Encode2Points(nn.Module):
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__()
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
encoder = cfg['model']['encoder']
|
||||||
|
decoder = cfg['model']['decoder']
|
||||||
|
dim = cfg['data']['dim'] # input dim
|
||||||
|
c_dim = cfg['model']['c_dim']
|
||||||
|
encoder_kwargs = cfg['model']['encoder_kwargs']
|
||||||
|
if encoder_kwargs == None:
|
||||||
|
encoder_kwargs = {}
|
||||||
|
decoder_kwargs = cfg['model']['decoder_kwargs']
|
||||||
|
padding = cfg['data']['padding']
|
||||||
|
self.predict_normal = cfg['model']['predict_normal']
|
||||||
|
self.predict_offset = cfg['model']['predict_offset']
|
||||||
|
|
||||||
|
out_dim = 3
|
||||||
|
out_dim_offset = 3
|
||||||
|
num_offset = cfg['data']['num_offset']
|
||||||
|
# each point predict more than one offset to add output points
|
||||||
|
if num_offset > 1:
|
||||||
|
out_dim_offset = out_dim * num_offset
|
||||||
|
self.num_offset = num_offset
|
||||||
|
|
||||||
|
# local mapping
|
||||||
|
self.map2local = None
|
||||||
|
if cfg['model']['local_coord']:
|
||||||
|
if 'unet' in encoder_kwargs.keys():
|
||||||
|
unit_size = 1 / encoder_kwargs['plane_resolution']
|
||||||
|
else:
|
||||||
|
unit_size = 1 / encoder_kwargs['grid_resolution']
|
||||||
|
|
||||||
|
local_mapping = map2local(unit_size)
|
||||||
|
|
||||||
|
self.encoder = encoder_dict[encoder](
|
||||||
|
dim=dim, c_dim=c_dim, map2local=local_mapping,
|
||||||
|
**encoder_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.predict_normal:
|
||||||
|
# decoder for normal prediction
|
||||||
|
self.decoder_normal = decoder_dict[decoder](
|
||||||
|
dim=dim, c_dim=c_dim, out_dim=out_dim,
|
||||||
|
**decoder_kwargs)
|
||||||
|
if self.predict_offset:
|
||||||
|
# decoder for offset prediction
|
||||||
|
self.decoder_offset = decoder_dict[decoder](
|
||||||
|
dim=dim, c_dim=c_dim, out_dim=out_dim_offset,
|
||||||
|
map2local=local_mapping,
|
||||||
|
**decoder_kwargs)
|
||||||
|
|
||||||
|
self.s_off = cfg['model']['s_offset']
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, p):
|
||||||
|
''' Performs a forward pass through the network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
p (tensor): input unoriented points
|
||||||
|
'''
|
||||||
|
|
||||||
|
time_dict = {}
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
batch_size = p.size(0)
|
||||||
|
points = p.clone()
|
||||||
|
|
||||||
|
# encode the input point cloud to a feature volume
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
c = self.encoder(p)
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
if self.predict_offset:
|
||||||
|
offset = self.decoder_offset(p, c)
|
||||||
|
# more than one offset is predicted per-point
|
||||||
|
if self.num_offset > 1:
|
||||||
|
points = points.repeat(1, 1, self.num_offset).reshape(batch_size, -1, 3)
|
||||||
|
points = points + self.s_off * offset
|
||||||
|
else:
|
||||||
|
points = p
|
||||||
|
|
||||||
|
if self.predict_normal:
|
||||||
|
normals = self.decoder_normal(points, c)
|
||||||
|
t2 = time.perf_counter()
|
||||||
|
|
||||||
|
time_dict['encode'] = t1 - t0
|
||||||
|
time_dict['predict'] = t2 - t1
|
||||||
|
|
||||||
|
points = torch.clamp(points, 0.0, 0.99)
|
||||||
|
if self.cfg['model']['normal_normalize']:
|
||||||
|
normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
return points, normals
|
||||||
|
|
139
src/model_rgb.py
Normal file
139
src/model_rgb.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
import torch
|
||||||
|
from src.network.net_rgb import RenderingNetwork
|
||||||
|
from src.utils import approx_psr_grad
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
RasterizationSettings,
|
||||||
|
PerspectiveCameras,
|
||||||
|
MeshRenderer,
|
||||||
|
MeshRasterizer,
|
||||||
|
SoftSilhouetteShader)
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
|
||||||
|
|
||||||
|
def approx_psr_grad(psr_grid, res, normalize=True):
|
||||||
|
delta_x = delta_y = delta_z = 1/res
|
||||||
|
psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()
|
||||||
|
|
||||||
|
grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x
|
||||||
|
grad_y = (psr_pad[:, 2:, :] - psr_pad[:, :-2, :]) / 2 / delta_y
|
||||||
|
grad_z = (psr_pad[:, :, 2:] - psr_pad[:, :, :-2]) / 2 / delta_z
|
||||||
|
grad_x = grad_x[:, 1:-1, 1:-1]
|
||||||
|
grad_y = grad_y[1:-1, :, 1:-1]
|
||||||
|
grad_z = grad_z[1:-1, 1:-1, :]
|
||||||
|
|
||||||
|
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=3) # [res_x, res_y, res_z, 3]
|
||||||
|
if normalize:
|
||||||
|
psr_grad = psr_grad / (psr_grad.norm(dim=3, keepdim=True) + 1e-12)
|
||||||
|
|
||||||
|
return psr_grad
|
||||||
|
|
||||||
|
|
||||||
|
class SAP2Image(nn.Module):
|
||||||
|
def __init__(self, cfg, img_size):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.psr2sur = PSR2SurfacePoints.apply
|
||||||
|
self.psr2mesh = PSR2Mesh.apply
|
||||||
|
# initialize DPSR
|
||||||
|
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
||||||
|
cfg['model']['grid_res'],
|
||||||
|
cfg['model']['grid_res']),
|
||||||
|
sig=cfg['model']['psr_sigma'])
|
||||||
|
self.cfg = cfg
|
||||||
|
if cfg['train']['l_weight']['rgb'] != 0.:
|
||||||
|
self.rendering_network = RenderingNetwork(**cfg['model']['renderer'])
|
||||||
|
|
||||||
|
if cfg['train']['l_weight']['mask'] != 0.:
|
||||||
|
# initialize rasterizer
|
||||||
|
sigma = 1e-4
|
||||||
|
raster_settings_soft = RasterizationSettings(
|
||||||
|
image_size=img_size,
|
||||||
|
blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
|
||||||
|
faces_per_pixel=150,
|
||||||
|
perspective_correct=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize silhouette renderer
|
||||||
|
self.mesh_rasterizer = MeshRenderer(
|
||||||
|
rasterizer=MeshRasterizer(
|
||||||
|
raster_settings=raster_settings_soft
|
||||||
|
),
|
||||||
|
shader=SoftSilhouetteShader()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
self.img_size = img_size
|
||||||
|
|
||||||
|
def forward(self, inputs, data):
|
||||||
|
points, normals = inputs[...,:3], inputs[...,3:]
|
||||||
|
points = torch.sigmoid(points)
|
||||||
|
normals = normals / normals.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# DPSR to get grid
|
||||||
|
psr_grid = self.dpsr(points, normals).unsqueeze(1)
|
||||||
|
psr_grid = torch.tanh(psr_grid)
|
||||||
|
|
||||||
|
return self.render_img(psr_grid, data)
|
||||||
|
|
||||||
|
def render_img(self, psr_grid, data):
|
||||||
|
|
||||||
|
n_views = len(data['masks'])
|
||||||
|
n_views_per_iter = self.cfg['data']['n_views_per_iter']
|
||||||
|
|
||||||
|
rgb_render_mode = self.cfg['model']['renderer']['mode']
|
||||||
|
uv = data['uv']
|
||||||
|
|
||||||
|
idx = np.random.randint(0, n_views, n_views_per_iter)
|
||||||
|
pose = [data['poses'][i] for i in idx]
|
||||||
|
rgb = data['rgbs'][idx]
|
||||||
|
mask_gt = data['masks'][idx]
|
||||||
|
ray = None
|
||||||
|
pred_rgb = None
|
||||||
|
pred_mask = None
|
||||||
|
|
||||||
|
if self.cfg['train']['l_weight']['rgb'] != 0.:
|
||||||
|
psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res'])
|
||||||
|
p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
|
||||||
|
n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
|
||||||
|
fea_interp = None
|
||||||
|
if 'rays' in data.keys():
|
||||||
|
ray = data['rays'].squeeze()[idx][visible_mask]
|
||||||
|
pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp)
|
||||||
|
|
||||||
|
# silhouette loss
|
||||||
|
if self.cfg['train']['l_weight']['mask'] != 0.:
|
||||||
|
# build mesh
|
||||||
|
v, f, _ = self.psr2mesh(psr_grid)
|
||||||
|
v = v * 2. - 1 # within the range of [-1, 1]
|
||||||
|
# ! Fast but more GPU usage
|
||||||
|
mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])
|
||||||
|
if True:
|
||||||
|
#! PyTorch3D silhouette loss
|
||||||
|
# build pose
|
||||||
|
R = torch.cat([p.R for p in pose], dim=0)
|
||||||
|
T = torch.cat([p.T for p in pose], dim=0)
|
||||||
|
focal = torch.cat([p.focal_length for p in pose], dim=0)
|
||||||
|
pp = torch.cat([p.principal_point for p in pose], dim=0)
|
||||||
|
pose_cur = PerspectiveCameras(
|
||||||
|
focal_length=focal,
|
||||||
|
principal_point=pp,
|
||||||
|
R=R, T=T,
|
||||||
|
device=R.device)
|
||||||
|
pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]
|
||||||
|
else:
|
||||||
|
pred_mask = []
|
||||||
|
# ! Slow but less GPU usage
|
||||||
|
for i in range(n_views_per_iter):
|
||||||
|
#! PyTorch3D silhouette loss
|
||||||
|
pred_mask.append(self.mesh_rasterizer(mesh, cameras=pose[i])[..., 3])
|
||||||
|
pred_mask = torch.cat(pred_mask, dim=0)
|
||||||
|
|
||||||
|
output = {
|
||||||
|
'rgb': pred_rgb,
|
||||||
|
'rgb_gt': rgb,
|
||||||
|
'mask': pred_mask,
|
||||||
|
'mask_gt': mask_gt,
|
||||||
|
'vis_mask': visible_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
return output
|
8
src/network/__init__.py
Normal file
8
src/network/__init__.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
from src.network import encoder, decoder
|
||||||
|
|
||||||
|
encoder_dict = {
|
||||||
|
'local_pool_pointnet': encoder.LocalPoolPointnet,
|
||||||
|
}
|
||||||
|
decoder_dict = {
|
||||||
|
'simple_local': decoder.LocalDecoder,
|
||||||
|
}
|
106
src/network/decoder.py
Normal file
106
src/network/decoder.py
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from ipdb import set_trace as st
|
||||||
|
from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \
|
||||||
|
normalize_coordinate
|
||||||
|
|
||||||
|
|
||||||
|
class LocalDecoder(nn.Module):
|
||||||
|
''' Decoder.
|
||||||
|
Instead of conditioning on global features, on plane/volume local features.
|
||||||
|
Args:
|
||||||
|
dim (int): input dimension
|
||||||
|
c_dim (int): dimension of latent conditioned code c
|
||||||
|
hidden_size (int): hidden size of Decoder network
|
||||||
|
n_blocks (int): number of blocks ResNetBlockFC layers
|
||||||
|
leaky (bool): whether to use leaky ReLUs
|
||||||
|
sample_mode (str): sampling feature strategy, bilinear|nearest
|
||||||
|
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, dim=3, c_dim=128, out_dim=3,
|
||||||
|
hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None):
|
||||||
|
super().__init__()
|
||||||
|
self.c_dim = c_dim
|
||||||
|
self.n_blocks = n_blocks
|
||||||
|
|
||||||
|
if c_dim != 0:
|
||||||
|
self.fc_c = nn.ModuleList([
|
||||||
|
nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
self.fc_p = nn.Linear(dim, hidden_size)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
ResnetBlockFC(hidden_size) for i in range(n_blocks)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.fc_out = nn.Linear(hidden_size, out_dim)
|
||||||
|
|
||||||
|
if not leaky:
|
||||||
|
self.actvn = F.relu
|
||||||
|
else:
|
||||||
|
self.actvn = lambda x: F.leaky_relu(x, 0.2)
|
||||||
|
|
||||||
|
self.sample_mode = sample_mode
|
||||||
|
self.padding = padding
|
||||||
|
self.map2local = map2local
|
||||||
|
self.out_dim = out_dim
|
||||||
|
|
||||||
|
|
||||||
|
def sample_plane_feature(self, p, c, plane='xz'):
|
||||||
|
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
|
||||||
|
xy = xy[:, :, None].float()
|
||||||
|
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
|
||||||
|
c = F.grid_sample(c, vgrid, padding_mode='border',
|
||||||
|
align_corners=True,
|
||||||
|
mode=self.sample_mode).squeeze(-1)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def sample_grid_feature(self, p, c):
|
||||||
|
p_nor = normalize_3d_coordinate(p.clone())
|
||||||
|
p_nor = p_nor[:, :, None, None].float()
|
||||||
|
vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
|
||||||
|
# acutally trilinear interpolation if mode = 'bilinear'
|
||||||
|
c = F.grid_sample(c, vgrid, padding_mode='border',
|
||||||
|
align_corners=True,
|
||||||
|
mode=self.sample_mode).squeeze(-1).squeeze(-1)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def forward(self, p, c_plane, **kwargs):
|
||||||
|
batch_size = p.shape[0]
|
||||||
|
plane_type = list(c_plane.keys())
|
||||||
|
c = 0
|
||||||
|
if 'grid' in plane_type:
|
||||||
|
c += self.sample_grid_feature(p, c_plane['grid'])
|
||||||
|
if 'xz' in plane_type:
|
||||||
|
c += self.sample_plane_feature(p, c_plane['xz'], plane='xz')
|
||||||
|
if 'xy' in plane_type:
|
||||||
|
c += self.sample_plane_feature(p, c_plane['xy'], plane='xy')
|
||||||
|
if 'yz' in plane_type:
|
||||||
|
c += self.sample_plane_feature(p, c_plane['yz'], plane='yz')
|
||||||
|
c = c.transpose(1, 2)
|
||||||
|
|
||||||
|
p = p.float()
|
||||||
|
|
||||||
|
if self.map2local:
|
||||||
|
p = self.map2local(p)
|
||||||
|
|
||||||
|
net = self.fc_p(p)
|
||||||
|
|
||||||
|
for i in range(self.n_blocks):
|
||||||
|
if self.c_dim != 0:
|
||||||
|
net = net + self.fc_c[i](c)
|
||||||
|
|
||||||
|
net = self.blocks[i](net)
|
||||||
|
|
||||||
|
out = self.fc_out(self.actvn(net))
|
||||||
|
|
||||||
|
|
||||||
|
if self.out_dim > 3:
|
||||||
|
out = out.reshape(batch_size, -1, 3)
|
||||||
|
|
||||||
|
return out
|
181
src/network/encoder.py
Normal file
181
src/network/encoder.py
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from src.network.unet3d import UNet3D
|
||||||
|
from src.network.unet import UNet
|
||||||
|
from ipdb import set_trace as st
|
||||||
|
from torch_scatter import scatter_mean, scatter_max
|
||||||
|
from src.network.utils import get_embedder, normalize_3d_coordinate,\
|
||||||
|
coordinate2index, ResnetBlockFC, normalize_coordinate
|
||||||
|
|
||||||
|
class LocalPoolPointnet(nn.Module):
|
||||||
|
''' PointNet-based encoder network with ResNet blocks for each point.
|
||||||
|
Number of input points are fixed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
c_dim (int): dimension of latent code c
|
||||||
|
dim (int): input points dimension
|
||||||
|
hidden_dim (int): hidden dimension of the network
|
||||||
|
scatter_type (str): feature aggregation when doing local pooling
|
||||||
|
unet (bool): weather to use U-Net
|
||||||
|
unet_kwargs (str): U-Net parameters
|
||||||
|
unet3d (bool): weather to use 3D U-Net
|
||||||
|
unet3d_kwargs (str): 3D U-Net parameters
|
||||||
|
plane_resolution (int): defined resolution for plane feature
|
||||||
|
grid_resolution (int): defined resolution for grid feature
|
||||||
|
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
|
||||||
|
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
||||||
|
n_blocks (int): number of blocks ResNetBlockFC layers
|
||||||
|
map2local (function): map global coordintes to local ones
|
||||||
|
pos_encoding (int): frequency for the positional encoding
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
|
||||||
|
unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
|
||||||
|
plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
|
||||||
|
map2local=None, pos_encoding=0):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.c_dim = c_dim
|
||||||
|
|
||||||
|
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
|
||||||
|
])
|
||||||
|
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
||||||
|
|
||||||
|
self.actvn = nn.ReLU()
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
|
||||||
|
self.unet = None
|
||||||
|
if unet:
|
||||||
|
self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
|
||||||
|
|
||||||
|
self.unet3d = None
|
||||||
|
if unet3d:
|
||||||
|
self.unet3d = UNet3D(**unet3d_kwargs)
|
||||||
|
|
||||||
|
self.reso_plane = plane_resolution
|
||||||
|
self.reso_grid = grid_resolution
|
||||||
|
self.plane_type = plane_type
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
||||||
|
self.pe = None
|
||||||
|
if pos_encoding > 0:
|
||||||
|
embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim)
|
||||||
|
self.pe = embed_fn
|
||||||
|
self.fc_pos = nn.Linear(input_ch, 2*hidden_dim)
|
||||||
|
|
||||||
|
self.map2local = map2local
|
||||||
|
|
||||||
|
|
||||||
|
if scatter_type == 'max':
|
||||||
|
self.scatter = scatter_max
|
||||||
|
elif scatter_type == 'mean':
|
||||||
|
self.scatter = scatter_mean
|
||||||
|
else:
|
||||||
|
raise ValueError('incorrect scatter type')
|
||||||
|
|
||||||
|
def generate_plane_features(self, p, c, plane='xz'):
|
||||||
|
# acquire indices of features in plane
|
||||||
|
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
|
||||||
|
index = coordinate2index(xy, self.reso_plane)
|
||||||
|
|
||||||
|
# scatter plane features from points
|
||||||
|
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
|
||||||
|
c = c.permute(0, 2, 1) # B x 512 x T
|
||||||
|
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
||||||
|
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso)
|
||||||
|
|
||||||
|
# process the plane features with UNet
|
||||||
|
if self.unet is not None:
|
||||||
|
fea_plane = self.unet(fea_plane)
|
||||||
|
|
||||||
|
return fea_plane
|
||||||
|
|
||||||
|
def generate_grid_features(self, p, c):
|
||||||
|
p_nor = normalize_3d_coordinate(p.clone())
|
||||||
|
index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
|
||||||
|
# scatter grid features from points
|
||||||
|
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
|
||||||
|
c = c.permute(0, 2, 1)
|
||||||
|
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
|
||||||
|
fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparce matrix (B x 512 x reso x reso)
|
||||||
|
|
||||||
|
if self.unet3d is not None:
|
||||||
|
fea_grid = self.unet3d(fea_grid)
|
||||||
|
|
||||||
|
return fea_grid
|
||||||
|
|
||||||
|
def pool_local(self, xy, index, c):
|
||||||
|
bs, fea_dim = c.size(0), c.size(2)
|
||||||
|
keys = xy.keys()
|
||||||
|
|
||||||
|
c_out = 0
|
||||||
|
for key in keys:
|
||||||
|
# scatter plane features from points
|
||||||
|
if key == 'grid':
|
||||||
|
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
|
||||||
|
else:
|
||||||
|
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
|
||||||
|
if self.scatter == scatter_max:
|
||||||
|
fea = fea[0]
|
||||||
|
# gather feature back to points
|
||||||
|
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
|
||||||
|
c_out += fea
|
||||||
|
return c_out.permute(0, 2, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, p, normalize=True):
|
||||||
|
batch_size, T, D = p.size()
|
||||||
|
|
||||||
|
# acquire the index for each point
|
||||||
|
coord = {}
|
||||||
|
index = {}
|
||||||
|
if 'xz' in self.plane_type:
|
||||||
|
coord['xz'] = normalize_coordinate(p.clone(), plane='xz')
|
||||||
|
index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
|
||||||
|
if 'xy' in self.plane_type:
|
||||||
|
coord['xy'] = normalize_coordinate(p.clone(), plane='xy')
|
||||||
|
index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
|
||||||
|
if 'yz' in self.plane_type:
|
||||||
|
coord['yz'] = normalize_coordinate(p.clone(), plane='yz')
|
||||||
|
index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
|
||||||
|
if 'grid' in self.plane_type:
|
||||||
|
if normalize:
|
||||||
|
coord['grid'] = normalize_3d_coordinate(p.clone())
|
||||||
|
else:
|
||||||
|
coord['grid'] = p.clone()[...,:3]
|
||||||
|
index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
|
||||||
|
|
||||||
|
|
||||||
|
if self.pe:
|
||||||
|
p = self.pe(p)
|
||||||
|
|
||||||
|
if self.map2local:
|
||||||
|
pp = self.map2local(p)
|
||||||
|
net = self.fc_pos(pp)
|
||||||
|
else:
|
||||||
|
net = self.fc_pos(p)
|
||||||
|
# net = self.fc_pos(p)
|
||||||
|
|
||||||
|
net = self.blocks[0](net)
|
||||||
|
for block in self.blocks[1:]:
|
||||||
|
pooled = self.pool_local(coord, index, net)
|
||||||
|
net = torch.cat([net, pooled], dim=2)
|
||||||
|
net = block(net)
|
||||||
|
|
||||||
|
c = self.fc_c(net)
|
||||||
|
|
||||||
|
fea = {}
|
||||||
|
if 'grid' in self.plane_type:
|
||||||
|
fea['grid'] = self.generate_grid_features(p, c)
|
||||||
|
if 'xz' in self.plane_type:
|
||||||
|
fea['xz'] = self.generate_plane_features(p, c, plane='xz')
|
||||||
|
if 'xy' in self.plane_type:
|
||||||
|
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
|
||||||
|
if 'yz' in self.plane_type:
|
||||||
|
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
|
||||||
|
|
||||||
|
return fea
|
234
src/network/net_rgb.py
Normal file
234
src/network/net_rgb.py
Normal file
|
@ -0,0 +1,234 @@
|
||||||
|
# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py)
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from src.network.utils import get_embedder
|
||||||
|
from pdb import set_trace as st
|
||||||
|
|
||||||
|
class RenderingNetwork(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fea_size=0,
|
||||||
|
mode='naive',
|
||||||
|
d_out=3,
|
||||||
|
dims=[512, 512, 512, 512],
|
||||||
|
weight_norm=True,
|
||||||
|
pe_freq_view=0 # for positional encoding
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.mode = mode
|
||||||
|
if mode == 'naive':
|
||||||
|
d_in = 3
|
||||||
|
elif mode == 'no_feature':
|
||||||
|
d_in = 3 + 3 + 3
|
||||||
|
fea_size = 0
|
||||||
|
elif mode == 'full':
|
||||||
|
d_in = 3 + 3 + 3
|
||||||
|
else:
|
||||||
|
d_in = 3 + 3
|
||||||
|
dims = [d_in + fea_size] + dims + [d_out]
|
||||||
|
|
||||||
|
self.embedview_fn = None
|
||||||
|
if pe_freq_view > 0:
|
||||||
|
embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3)
|
||||||
|
self.embedview_fn = embedview_fn
|
||||||
|
dims[0] += (input_ch - 3)
|
||||||
|
|
||||||
|
self.num_layers = len(dims)
|
||||||
|
|
||||||
|
for l in range(0, self.num_layers - 1):
|
||||||
|
out_dim = dims[l + 1]
|
||||||
|
lin = nn.Linear(dims[l], out_dim)
|
||||||
|
|
||||||
|
if weight_norm:
|
||||||
|
lin = nn.utils.weight_norm(lin)
|
||||||
|
|
||||||
|
setattr(self, "lin" + str(l), lin)
|
||||||
|
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.tanh = nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, points, normals=None, view_dirs=None, feature_vectors=None):
|
||||||
|
if self.embedview_fn is not None:
|
||||||
|
view_dirs = self.embedview_fn(view_dirs)
|
||||||
|
# points = self.embedview_fn(points)
|
||||||
|
|
||||||
|
if (self.mode == 'full') & (feature_vectors is not None):
|
||||||
|
rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
|
||||||
|
elif (self.mode == 'no_feature') | ((self.mode == 'full') & (feature_vectors is None)):
|
||||||
|
rendering_input = torch.cat([points, view_dirs, normals], dim=-1)
|
||||||
|
elif self.mode == 'no_view_dir':
|
||||||
|
rendering_input = torch.cat([points, normals], dim=-1)
|
||||||
|
elif self.mode == 'no_normal':
|
||||||
|
rendering_input = torch.cat([points, view_dirs], dim=-1)
|
||||||
|
else:
|
||||||
|
rendering_input = points
|
||||||
|
|
||||||
|
x = rendering_input
|
||||||
|
|
||||||
|
for l in range(0, self.num_layers - 1):
|
||||||
|
lin = getattr(self, "lin" + str(l))
|
||||||
|
|
||||||
|
x = lin(x)
|
||||||
|
|
||||||
|
if l < self.num_layers - 2:
|
||||||
|
x = self.relu(x)
|
||||||
|
|
||||||
|
x = self.tanh(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class NeRFRenderingNetwork(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
feature_vector_size=0,
|
||||||
|
mode='naive',
|
||||||
|
d_in=3,
|
||||||
|
d_out=3,
|
||||||
|
dims=[512, 512, 512, 256],
|
||||||
|
weight_norm=True,
|
||||||
|
multires=0, # positional encoding of points
|
||||||
|
multires_view=0 # positional encoding of view
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.mode = mode
|
||||||
|
dims = [d_in + feature_vector_size] + dims
|
||||||
|
|
||||||
|
|
||||||
|
self.embed_fn = None
|
||||||
|
if multires > 0:
|
||||||
|
embed_fn, input_ch = get_embedder(multires, d_in=d_in)
|
||||||
|
self.embed_fn = embed_fn
|
||||||
|
dims[0] += (input_ch - 3)
|
||||||
|
|
||||||
|
self.num_layers = len(dims)
|
||||||
|
|
||||||
|
self.pts_net = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.num_layers - 1)])
|
||||||
|
|
||||||
|
self.embedview_fn = None
|
||||||
|
if multires_view > 0:
|
||||||
|
embedview_fn, view_ch = get_embedder(multires_view, d_in=3)
|
||||||
|
self.embedview_fn = embedview_fn
|
||||||
|
# dims[0] += (input_ch - 3)
|
||||||
|
|
||||||
|
if mode == 'full':
|
||||||
|
self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)])
|
||||||
|
self.rgb_net = nn.Linear(128, 3)
|
||||||
|
else:
|
||||||
|
self.rgb_net = nn.Linear(dims[-1], 3)
|
||||||
|
|
||||||
|
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.tanh = nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, points, normals=None, view_dirs=None, feature_vectors=None):
|
||||||
|
if self.embed_fn is not None:
|
||||||
|
points = self.embed_fn(points)
|
||||||
|
if self.embedview_fn is not None:
|
||||||
|
view_dirs = self.embedview_fn(view_dirs)
|
||||||
|
|
||||||
|
x = points
|
||||||
|
for net in self.pts_net:
|
||||||
|
x = net(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
|
||||||
|
if self.mode=='full':
|
||||||
|
x = torch.cat([x, view_dirs], -1)
|
||||||
|
for net in self.view_net:
|
||||||
|
x = net(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
|
||||||
|
x = self.rgb_net(x)
|
||||||
|
x = self.tanh(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class ImplicitNetwork(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_in,
|
||||||
|
d_out,
|
||||||
|
dims,
|
||||||
|
geometric_init=True,
|
||||||
|
feature_vector_size=0,
|
||||||
|
bias=1.0,
|
||||||
|
skip_in=(),
|
||||||
|
weight_norm=True,
|
||||||
|
multires=0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dims = [d_in] + dims + [d_out + feature_vector_size]
|
||||||
|
|
||||||
|
self.embed_fn = None
|
||||||
|
if multires > 0:
|
||||||
|
embed_fn, input_ch = get_embedder(multires)
|
||||||
|
self.embed_fn = embed_fn
|
||||||
|
dims[0] = input_ch
|
||||||
|
|
||||||
|
self.num_layers = len(dims)
|
||||||
|
self.skip_in = skip_in
|
||||||
|
|
||||||
|
for l in range(0, self.num_layers - 1):
|
||||||
|
if l + 1 in self.skip_in:
|
||||||
|
out_dim = dims[l + 1] - dims[0]
|
||||||
|
else:
|
||||||
|
out_dim = dims[l + 1]
|
||||||
|
|
||||||
|
lin = nn.Linear(dims[l], out_dim)
|
||||||
|
|
||||||
|
if geometric_init:
|
||||||
|
if l == self.num_layers - 2:
|
||||||
|
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
|
||||||
|
torch.nn.init.constant_(lin.bias, -bias)
|
||||||
|
elif multires > 0 and l == 0:
|
||||||
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
|
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
||||||
|
torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
||||||
|
elif multires > 0 and l in self.skip_in:
|
||||||
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
|
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
||||||
|
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
|
||||||
|
else:
|
||||||
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
|
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
||||||
|
|
||||||
|
if weight_norm:
|
||||||
|
lin = nn.utils.weight_norm(lin)
|
||||||
|
|
||||||
|
setattr(self, "lin" + str(l), lin)
|
||||||
|
|
||||||
|
self.softplus = nn.Softplus(beta=100)
|
||||||
|
|
||||||
|
def forward(self, input, compute_grad=False):
|
||||||
|
if self.embed_fn is not None:
|
||||||
|
input = self.embed_fn(input)
|
||||||
|
|
||||||
|
x = input
|
||||||
|
|
||||||
|
for l in range(0, self.num_layers - 1):
|
||||||
|
lin = getattr(self, "lin" + str(l))
|
||||||
|
|
||||||
|
if l in self.skip_in:
|
||||||
|
x = torch.cat([x, input], 1) / np.sqrt(2)
|
||||||
|
|
||||||
|
x = lin(x)
|
||||||
|
|
||||||
|
if l < self.num_layers - 2:
|
||||||
|
x = self.softplus(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def gradient(self, x):
|
||||||
|
x.requires_grad_(True)
|
||||||
|
y = self.forward(x)[:,:1]
|
||||||
|
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
||||||
|
gradients = torch.autograd.grad(
|
||||||
|
outputs=y,
|
||||||
|
inputs=x,
|
||||||
|
grad_outputs=d_output,
|
||||||
|
create_graph=True,
|
||||||
|
retain_graph=True,
|
||||||
|
only_inputs=True)[0]
|
||||||
|
return gradients.unsqueeze(1)
|
256
src/network/unet.py
Normal file
256
src/network/unet.py
Normal file
|
@ -0,0 +1,256 @@
|
||||||
|
'''
|
||||||
|
Codes are from:
|
||||||
|
https://github.com/jaxony/unet-pytorch/blob/master/model.py
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import init
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def conv3x3(in_channels, out_channels, stride=1,
|
||||||
|
padding=1, bias=True, groups=1):
|
||||||
|
return nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
bias=bias,
|
||||||
|
groups=groups)
|
||||||
|
|
||||||
|
def upconv2x2(in_channels, out_channels, mode='transpose'):
|
||||||
|
if mode == 'transpose':
|
||||||
|
return nn.ConvTranspose2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=2,
|
||||||
|
stride=2)
|
||||||
|
else:
|
||||||
|
# out_channels is always going to be the same
|
||||||
|
# as in_channels
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Upsample(mode='bilinear', scale_factor=2),
|
||||||
|
conv1x1(in_channels, out_channels))
|
||||||
|
|
||||||
|
def conv1x1(in_channels, out_channels, groups=1):
|
||||||
|
return nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
groups=groups,
|
||||||
|
stride=1)
|
||||||
|
|
||||||
|
|
||||||
|
class DownConv(nn.Module):
|
||||||
|
"""
|
||||||
|
A helper Module that performs 2 convolutions and 1 MaxPool.
|
||||||
|
A ReLU activation follows each convolution.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, pooling=True):
|
||||||
|
super(DownConv, self).__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.pooling = pooling
|
||||||
|
|
||||||
|
self.conv1 = conv3x3(self.in_channels, self.out_channels)
|
||||||
|
self.conv2 = conv3x3(self.out_channels, self.out_channels)
|
||||||
|
|
||||||
|
if self.pooling:
|
||||||
|
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.relu(self.conv1(x))
|
||||||
|
x = F.relu(self.conv2(x))
|
||||||
|
before_pool = x
|
||||||
|
if self.pooling:
|
||||||
|
x = self.pool(x)
|
||||||
|
return x, before_pool
|
||||||
|
|
||||||
|
|
||||||
|
class UpConv(nn.Module):
|
||||||
|
"""
|
||||||
|
A helper Module that performs 2 convolutions and 1 UpConvolution.
|
||||||
|
A ReLU activation follows each convolution.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels,
|
||||||
|
merge_mode='concat', up_mode='transpose'):
|
||||||
|
super(UpConv, self).__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.merge_mode = merge_mode
|
||||||
|
self.up_mode = up_mode
|
||||||
|
|
||||||
|
self.upconv = upconv2x2(self.in_channels, self.out_channels,
|
||||||
|
mode=self.up_mode)
|
||||||
|
|
||||||
|
if self.merge_mode == 'concat':
|
||||||
|
self.conv1 = conv3x3(
|
||||||
|
2*self.out_channels, self.out_channels)
|
||||||
|
else:
|
||||||
|
# num of input channels to conv2 is same
|
||||||
|
self.conv1 = conv3x3(self.out_channels, self.out_channels)
|
||||||
|
self.conv2 = conv3x3(self.out_channels, self.out_channels)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, from_down, from_up):
|
||||||
|
""" Forward pass
|
||||||
|
Arguments:
|
||||||
|
from_down: tensor from the encoder pathway
|
||||||
|
from_up: upconv'd tensor from the decoder pathway
|
||||||
|
"""
|
||||||
|
from_up = self.upconv(from_up)
|
||||||
|
if self.merge_mode == 'concat':
|
||||||
|
x = torch.cat((from_up, from_down), 1)
|
||||||
|
else:
|
||||||
|
x = from_up + from_down
|
||||||
|
x = F.relu(self.conv1(x))
|
||||||
|
x = F.relu(self.conv2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UNet(nn.Module):
|
||||||
|
""" `UNet` class is based on https://arxiv.org/abs/1505.04597
|
||||||
|
|
||||||
|
The U-Net is a convolutional encoder-decoder neural network.
|
||||||
|
Contextual spatial information (from the decoding,
|
||||||
|
expansive pathway) about an input tensor is merged with
|
||||||
|
information representing the localization of details
|
||||||
|
(from the encoding, compressive pathway).
|
||||||
|
|
||||||
|
Modifications to the original paper:
|
||||||
|
(1) padding is used in 3x3 convolutions to prevent loss
|
||||||
|
of border pixels
|
||||||
|
(2) merging outputs does not require cropping due to (1)
|
||||||
|
(3) residual connections can be used by specifying
|
||||||
|
UNet(merge_mode='add')
|
||||||
|
(4) if non-parametric upsampling is used in the decoder
|
||||||
|
pathway (specified by upmode='upsample'), then an
|
||||||
|
additional 1x1 2d convolution occurs after upsampling
|
||||||
|
to reduce channel dimensionality by a factor of 2.
|
||||||
|
This channel halving happens with the convolution in
|
||||||
|
the tranpose convolution (specified by upmode='transpose')
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_classes, in_channels=3, depth=5,
|
||||||
|
start_filts=64, up_mode='transpose',
|
||||||
|
merge_mode='concat', **kwargs):
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
in_channels: int, number of channels in the input tensor.
|
||||||
|
Default is 3 for RGB images.
|
||||||
|
depth: int, number of MaxPools in the U-Net.
|
||||||
|
start_filts: int, number of convolutional filters for the
|
||||||
|
first conv.
|
||||||
|
up_mode: string, type of upconvolution. Choices: 'transpose'
|
||||||
|
for transpose convolution or 'upsample' for nearest neighbour
|
||||||
|
upsampling.
|
||||||
|
"""
|
||||||
|
super(UNet, self).__init__()
|
||||||
|
|
||||||
|
if up_mode in ('transpose', 'upsample'):
|
||||||
|
self.up_mode = up_mode
|
||||||
|
else:
|
||||||
|
raise ValueError("\"{}\" is not a valid mode for "
|
||||||
|
"upsampling. Only \"transpose\" and "
|
||||||
|
"\"upsample\" are allowed.".format(up_mode))
|
||||||
|
|
||||||
|
if merge_mode in ('concat', 'add'):
|
||||||
|
self.merge_mode = merge_mode
|
||||||
|
else:
|
||||||
|
raise ValueError("\"{}\" is not a valid mode for"
|
||||||
|
"merging up and down paths. "
|
||||||
|
"Only \"concat\" and "
|
||||||
|
"\"add\" are allowed.".format(up_mode))
|
||||||
|
|
||||||
|
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
|
||||||
|
if self.up_mode == 'upsample' and self.merge_mode == 'add':
|
||||||
|
raise ValueError("up_mode \"upsample\" is incompatible "
|
||||||
|
"with merge_mode \"add\" at the moment "
|
||||||
|
"because it doesn't make sense to use "
|
||||||
|
"nearest neighbour to reduce "
|
||||||
|
"depth channels (by half).")
|
||||||
|
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.start_filts = start_filts
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
self.down_convs = []
|
||||||
|
self.up_convs = []
|
||||||
|
|
||||||
|
# create the encoder pathway and add to a list
|
||||||
|
for i in range(depth):
|
||||||
|
ins = self.in_channels if i == 0 else outs
|
||||||
|
outs = self.start_filts*(2**i)
|
||||||
|
pooling = True if i < depth-1 else False
|
||||||
|
|
||||||
|
down_conv = DownConv(ins, outs, pooling=pooling)
|
||||||
|
self.down_convs.append(down_conv)
|
||||||
|
|
||||||
|
# create the decoder pathway and add to a list
|
||||||
|
# - careful! decoding only requires depth-1 blocks
|
||||||
|
for i in range(depth-1):
|
||||||
|
ins = outs
|
||||||
|
outs = ins // 2
|
||||||
|
up_conv = UpConv(ins, outs, up_mode=up_mode,
|
||||||
|
merge_mode=merge_mode)
|
||||||
|
self.up_convs.append(up_conv)
|
||||||
|
|
||||||
|
# add the list of modules to current module
|
||||||
|
self.down_convs = nn.ModuleList(self.down_convs)
|
||||||
|
self.up_convs = nn.ModuleList(self.up_convs)
|
||||||
|
|
||||||
|
self.conv_final = conv1x1(outs, self.num_classes)
|
||||||
|
|
||||||
|
self.reset_params()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def weight_init(m):
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.xavier_normal_(m.weight)
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_params(self):
|
||||||
|
for i, m in enumerate(self.modules()):
|
||||||
|
self.weight_init(m)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
encoder_outs = []
|
||||||
|
# encoder pathway, save outputs for merging
|
||||||
|
for i, module in enumerate(self.down_convs):
|
||||||
|
x, before_pool = module(x)
|
||||||
|
encoder_outs.append(before_pool)
|
||||||
|
for i, module in enumerate(self.up_convs):
|
||||||
|
before_pool = encoder_outs[-(i+2)]
|
||||||
|
x = module(before_pool, x)
|
||||||
|
|
||||||
|
# No softmax is used. This means you need to use
|
||||||
|
# nn.CrossEntropyLoss is your training script,
|
||||||
|
# as this module includes a softmax already.
|
||||||
|
x = self.conv_final(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""
|
||||||
|
testing
|
||||||
|
"""
|
||||||
|
model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
|
||||||
|
print(model)
|
||||||
|
print(sum(p.numel() for p in model.parameters()))
|
||||||
|
|
||||||
|
reso = 176
|
||||||
|
x = np.zeros((1, 1, reso, reso))
|
||||||
|
x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
|
||||||
|
x = torch.FloatTensor(x)
|
||||||
|
|
||||||
|
out = model(x)
|
||||||
|
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
|
||||||
|
|
||||||
|
# loss = torch.sum(out)
|
||||||
|
# loss.backward()
|
559
src/network/unet3d.py
Normal file
559
src/network/unet3d.py
Normal file
|
@ -0,0 +1,559 @@
|
||||||
|
'''
|
||||||
|
Code from the 3D UNet implementation:
|
||||||
|
https://github.com/wolny/pytorch-3dunet/
|
||||||
|
'''
|
||||||
|
import importlib
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from functools import partial
|
||||||
|
from src.network.utils import get_embedder
|
||||||
|
|
||||||
|
def number_of_features_per_level(init_channel_number, num_levels):
|
||||||
|
return [init_channel_number * 2 ** k for k in range(num_levels)]
|
||||||
|
|
||||||
|
|
||||||
|
def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
|
||||||
|
return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
|
||||||
|
|
||||||
|
|
||||||
|
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
|
||||||
|
"""
|
||||||
|
Create a list of modules with together constitute a single conv layer with non-linearity
|
||||||
|
and optional batchnorm/groupnorm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels
|
||||||
|
out_channels (int): number of output channels
|
||||||
|
order (string): order of things, e.g.
|
||||||
|
'cr' -> conv + ReLU
|
||||||
|
'gcr' -> groupnorm + conv + ReLU
|
||||||
|
'cl' -> conv + LeakyReLU
|
||||||
|
'ce' -> conv + ELU
|
||||||
|
'bcr' -> batchnorm + conv + ReLU
|
||||||
|
num_groups (int): number of groups for the GroupNorm
|
||||||
|
padding (int): add zero-padding to the input
|
||||||
|
|
||||||
|
Return:
|
||||||
|
list of tuple (name, module)
|
||||||
|
"""
|
||||||
|
assert 'c' in order, "Conv layer MUST be present"
|
||||||
|
assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
|
||||||
|
|
||||||
|
modules = []
|
||||||
|
for i, char in enumerate(order):
|
||||||
|
if char == 'r':
|
||||||
|
modules.append(('ReLU', nn.ReLU(inplace=True)))
|
||||||
|
elif char == 'l':
|
||||||
|
modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
|
||||||
|
elif char == 'e':
|
||||||
|
modules.append(('ELU', nn.ELU(inplace=True)))
|
||||||
|
elif char == 'c':
|
||||||
|
# add learnable bias only in the absence of batchnorm/groupnorm
|
||||||
|
bias = not ('g' in order or 'b' in order)
|
||||||
|
modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
|
||||||
|
elif char == 'g':
|
||||||
|
is_before_conv = i < order.index('c')
|
||||||
|
if is_before_conv:
|
||||||
|
num_channels = in_channels
|
||||||
|
else:
|
||||||
|
num_channels = out_channels
|
||||||
|
|
||||||
|
# use only one group if the given number of groups is greater than the number of channels
|
||||||
|
if num_channels < num_groups:
|
||||||
|
num_groups = 1
|
||||||
|
|
||||||
|
assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
|
||||||
|
modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
|
||||||
|
elif char == 'b':
|
||||||
|
is_before_conv = i < order.index('c')
|
||||||
|
if is_before_conv:
|
||||||
|
modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
|
||||||
|
else:
|
||||||
|
modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
|
||||||
|
|
||||||
|
return modules
|
||||||
|
|
||||||
|
|
||||||
|
class SingleConv(nn.Sequential):
|
||||||
|
"""
|
||||||
|
Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
|
||||||
|
of operations can be specified via the `order` parameter
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels
|
||||||
|
out_channels (int): number of output channels
|
||||||
|
kernel_size (int): size of the convolving kernel
|
||||||
|
order (string): determines the order of layers, e.g.
|
||||||
|
'cr' -> conv + ReLU
|
||||||
|
'crg' -> conv + ReLU + groupnorm
|
||||||
|
'cl' -> conv + LeakyReLU
|
||||||
|
'ce' -> conv + ELU
|
||||||
|
num_groups (int): number of groups for the GroupNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
|
||||||
|
super(SingleConv, self).__init__()
|
||||||
|
|
||||||
|
for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
|
||||||
|
self.add_module(name, module)
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleConv(nn.Sequential):
|
||||||
|
"""
|
||||||
|
A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
|
||||||
|
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
||||||
|
This can be changed however by providing the 'order' argument, e.g. in order
|
||||||
|
to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
|
||||||
|
Use padded convolutions to make sure that the output (H_out, W_out) is the same
|
||||||
|
as (H_in, W_in), so that you don't have to crop in the decoder path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels
|
||||||
|
out_channels (int): number of output channels
|
||||||
|
encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
|
||||||
|
kernel_size (int): size of the convolving kernel
|
||||||
|
order (string): determines the order of layers, e.g.
|
||||||
|
'cr' -> conv + ReLU
|
||||||
|
'crg' -> conv + ReLU + groupnorm
|
||||||
|
'cl' -> conv + LeakyReLU
|
||||||
|
'ce' -> conv + ELU
|
||||||
|
num_groups (int): number of groups for the GroupNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
|
||||||
|
super(DoubleConv, self).__init__()
|
||||||
|
if encoder:
|
||||||
|
# we're in the encoder path
|
||||||
|
conv1_in_channels = in_channels
|
||||||
|
conv1_out_channels = out_channels // 2
|
||||||
|
if conv1_out_channels < in_channels:
|
||||||
|
conv1_out_channels = in_channels
|
||||||
|
conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
|
||||||
|
else:
|
||||||
|
# we're in the decoder path, decrease the number of channels in the 1st convolution
|
||||||
|
conv1_in_channels, conv1_out_channels = in_channels, out_channels
|
||||||
|
conv2_in_channels, conv2_out_channels = out_channels, out_channels
|
||||||
|
|
||||||
|
# conv1
|
||||||
|
self.add_module('SingleConv1',
|
||||||
|
SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
|
||||||
|
# conv2
|
||||||
|
self.add_module('SingleConv2',
|
||||||
|
SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
|
||||||
|
|
||||||
|
|
||||||
|
class ExtResNetBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Basic UNet block consisting of a SingleConv followed by the residual block.
|
||||||
|
The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
|
||||||
|
of output channels is compatible with the residual block that follows.
|
||||||
|
This block can be used instead of standard DoubleConv in the Encoder module.
|
||||||
|
Motivated by: https://arxiv.org/pdf/1706.00120.pdf
|
||||||
|
|
||||||
|
Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
|
||||||
|
super(ExtResNetBlock, self).__init__()
|
||||||
|
|
||||||
|
# first convolution
|
||||||
|
self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
|
||||||
|
# residual block
|
||||||
|
self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
|
||||||
|
# remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
|
||||||
|
n_order = order
|
||||||
|
for c in 'rel':
|
||||||
|
n_order = n_order.replace(c, '')
|
||||||
|
self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
|
||||||
|
num_groups=num_groups)
|
||||||
|
|
||||||
|
# create non-linearity separately
|
||||||
|
if 'l' in order:
|
||||||
|
self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
elif 'e' in order:
|
||||||
|
self.non_linearity = nn.ELU(inplace=True)
|
||||||
|
else:
|
||||||
|
self.non_linearity = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# apply first convolution and save the output as a residual
|
||||||
|
out = self.conv1(x)
|
||||||
|
residual = out
|
||||||
|
|
||||||
|
# residual block
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.conv3(out)
|
||||||
|
|
||||||
|
out += residual
|
||||||
|
out = self.non_linearity(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
"""
|
||||||
|
A single module from the encoder path consisting of the optional max
|
||||||
|
pooling layer (one may specify the MaxPool kernel_size to be different
|
||||||
|
than the standard (2,2,2), e.g. if the volumetric data is anisotropic
|
||||||
|
(make sure to use complementary scale_factor in the decoder path) followed by
|
||||||
|
a DoubleConv module.
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels
|
||||||
|
out_channels (int): number of output channels
|
||||||
|
conv_kernel_size (int): size of the convolving kernel
|
||||||
|
apply_pooling (bool): if True use MaxPool3d before DoubleConv
|
||||||
|
pool_kernel_size (tuple): the size of the window to take a max over
|
||||||
|
pool_type (str): pooling layer: 'max' or 'avg'
|
||||||
|
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
||||||
|
conv_layer_order (string): determines the order of layers
|
||||||
|
in `DoubleConv` module. See `DoubleConv` for more info.
|
||||||
|
num_groups (int): number of groups for the GroupNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
|
||||||
|
pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg',
|
||||||
|
num_groups=8):
|
||||||
|
super(Encoder, self).__init__()
|
||||||
|
assert pool_type in ['max', 'avg']
|
||||||
|
if apply_pooling:
|
||||||
|
if pool_type == 'max':
|
||||||
|
self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
|
||||||
|
else:
|
||||||
|
self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
|
||||||
|
else:
|
||||||
|
self.pooling = None
|
||||||
|
|
||||||
|
self.basic_module = basic_module(in_channels, out_channels,
|
||||||
|
encoder=True,
|
||||||
|
kernel_size=conv_kernel_size,
|
||||||
|
order=conv_layer_order,
|
||||||
|
num_groups=num_groups)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.pooling is not None:
|
||||||
|
x = self.pooling(x)
|
||||||
|
x = self.basic_module(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
"""
|
||||||
|
A single module for decoder path consisting of the upsampling layer
|
||||||
|
(either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels
|
||||||
|
out_channels (int): number of output channels
|
||||||
|
kernel_size (int): size of the convolving kernel
|
||||||
|
scale_factor (tuple): used as the multiplier for the image H/W/D in
|
||||||
|
case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
|
||||||
|
from the corresponding encoder
|
||||||
|
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
||||||
|
conv_layer_order (string): determines the order of layers
|
||||||
|
in `DoubleConv` module. See `DoubleConv` for more info.
|
||||||
|
num_groups (int): number of groups for the GroupNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
|
||||||
|
conv_layer_order='crg', num_groups=8, mode='nearest'):
|
||||||
|
super(Decoder, self).__init__()
|
||||||
|
if basic_module == DoubleConv:
|
||||||
|
# if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
|
||||||
|
self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
|
||||||
|
# concat joining
|
||||||
|
self.joining = partial(self._joining, concat=True)
|
||||||
|
else:
|
||||||
|
# if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
|
||||||
|
self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
|
||||||
|
# sum joining
|
||||||
|
self.joining = partial(self._joining, concat=False)
|
||||||
|
# adapt the number of in_channels for the ExtResNetBlock
|
||||||
|
in_channels = out_channels
|
||||||
|
|
||||||
|
self.basic_module = basic_module(in_channels, out_channels,
|
||||||
|
encoder=False,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
order=conv_layer_order,
|
||||||
|
num_groups=num_groups)
|
||||||
|
|
||||||
|
def forward(self, encoder_features, x):
|
||||||
|
x = self.upsampling(encoder_features=encoder_features, x=x)
|
||||||
|
x = self.joining(encoder_features, x)
|
||||||
|
x = self.basic_module(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _joining(encoder_features, x, concat):
|
||||||
|
if concat:
|
||||||
|
return torch.cat((encoder_features, x), dim=1)
|
||||||
|
else:
|
||||||
|
return encoder_features + x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsampling(nn.Module):
|
||||||
|
"""
|
||||||
|
Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
|
||||||
|
concat_joining (bool): if True uses concatenation joining between encoder and decoder features, otherwise
|
||||||
|
uses summation joining (see Residual U-Net)
|
||||||
|
in_channels (int): number of input channels for transposed conv
|
||||||
|
out_channels (int): number of output channels for transpose conv
|
||||||
|
kernel_size (int or tuple): size of the convolving kernel
|
||||||
|
scale_factor (int or tuple): stride of the convolution
|
||||||
|
mode (str): algorithm used for upsampling:
|
||||||
|
'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3,
|
||||||
|
scale_factor=(2, 2, 2), mode='nearest'):
|
||||||
|
super(Upsampling, self).__init__()
|
||||||
|
|
||||||
|
if transposed_conv:
|
||||||
|
# make sure that the output size reverses the MaxPool3d from the corresponding encoder
|
||||||
|
# (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0])
|
||||||
|
self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
|
||||||
|
padding=1)
|
||||||
|
else:
|
||||||
|
self.upsample = partial(self._interpolate, mode=mode)
|
||||||
|
|
||||||
|
def forward(self, encoder_features, x):
|
||||||
|
output_size = encoder_features.size()[2:]
|
||||||
|
return self.upsample(x, output_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _interpolate(x, size, mode):
|
||||||
|
return F.interpolate(x, size=size, mode=mode)
|
||||||
|
|
||||||
|
|
||||||
|
class FinalConv(nn.Sequential):
|
||||||
|
"""
|
||||||
|
A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
|
||||||
|
which reduces the number of channels to 'out_channels'.
|
||||||
|
with the number of output channels 'out_channels // 2' and 'out_channels' respectively.
|
||||||
|
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
||||||
|
This can be change however by providing the 'order' argument, e.g. in order
|
||||||
|
to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels
|
||||||
|
out_channels (int): number of output channels
|
||||||
|
kernel_size (int): size of the convolving kernel
|
||||||
|
order (string): determines the order of layers, e.g.
|
||||||
|
'cr' -> conv + ReLU
|
||||||
|
'crg' -> conv + ReLU + groupnorm
|
||||||
|
num_groups (int): number of groups for the GroupNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
|
||||||
|
super(FinalConv, self).__init__()
|
||||||
|
|
||||||
|
# conv1
|
||||||
|
self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
|
||||||
|
|
||||||
|
# in the last layer a 1×1 convolution reduces the number of output channels to out_channels
|
||||||
|
final_conv = nn.Conv3d(in_channels, out_channels, 1)
|
||||||
|
self.add_module('final_conv', final_conv)
|
||||||
|
|
||||||
|
class Abstract3DUNet(nn.Module):
|
||||||
|
"""
|
||||||
|
Base class for standard and residual UNet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels
|
||||||
|
out_channels (int): number of output segmentation masks;
|
||||||
|
Note that that the of out_channels might correspond to either
|
||||||
|
different semantic classes or to different binary segmentation mask.
|
||||||
|
It's up to the user of the class to interpret the out_channels and
|
||||||
|
use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
|
||||||
|
or BCEWithLogitsLoss (two-class) respectively)
|
||||||
|
f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
|
||||||
|
of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4
|
||||||
|
final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
|
||||||
|
final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used
|
||||||
|
to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model.
|
||||||
|
basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....)
|
||||||
|
layer_order (string): determines the order of layers
|
||||||
|
in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
|
||||||
|
See `SingleConv` for more info
|
||||||
|
f_maps (int, tuple): if int: number of feature maps in the first conv layer of the encoder (default: 64);
|
||||||
|
if tuple: number of feature maps at each level
|
||||||
|
num_groups (int): number of groups for the GroupNorm
|
||||||
|
num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
|
||||||
|
is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied
|
||||||
|
after the final convolution; if False (regression problem) the normalization layer is skipped at the end
|
||||||
|
testing (bool): if True (testing mode) the `final_activation` (if present, i.e. `is_segmentation=true`)
|
||||||
|
will be applied as the last operation during the forward pass; if False the model is in training mode
|
||||||
|
and the `final_activation` (even if present) won't be applied; default: False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
|
||||||
|
num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs):
|
||||||
|
super(Abstract3DUNet, self).__init__()
|
||||||
|
|
||||||
|
self.testing = testing
|
||||||
|
|
||||||
|
if isinstance(f_maps, int):
|
||||||
|
f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
|
||||||
|
|
||||||
|
# create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
|
||||||
|
|
||||||
|
self.embed_fn = None
|
||||||
|
if pe_freq > 0:
|
||||||
|
embed_fn, input_ch = get_embedder(pe_freq, d_in=in_channels)
|
||||||
|
self.embed_fn = embed_fn
|
||||||
|
in_channels = input_ch
|
||||||
|
|
||||||
|
encoders = []
|
||||||
|
for i, out_feature_num in enumerate(f_maps):
|
||||||
|
if i == 0:
|
||||||
|
encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module,
|
||||||
|
conv_layer_order=layer_order, num_groups=num_groups)
|
||||||
|
else:
|
||||||
|
# TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
|
||||||
|
# currently pools with a constant kernel: (2, 2, 2)
|
||||||
|
encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module,
|
||||||
|
conv_layer_order=layer_order, num_groups=num_groups)
|
||||||
|
encoders.append(encoder)
|
||||||
|
|
||||||
|
self.encoders = nn.ModuleList(encoders)
|
||||||
|
|
||||||
|
# create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
|
||||||
|
decoders = []
|
||||||
|
reversed_f_maps = list(reversed(f_maps))
|
||||||
|
for i in range(len(reversed_f_maps) - 1):
|
||||||
|
if basic_module == DoubleConv:
|
||||||
|
in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
|
||||||
|
else:
|
||||||
|
in_feature_num = reversed_f_maps[i]
|
||||||
|
|
||||||
|
out_feature_num = reversed_f_maps[i + 1]
|
||||||
|
# TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
|
||||||
|
# currently strides with a constant stride: (2, 2, 2)
|
||||||
|
decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module,
|
||||||
|
conv_layer_order=layer_order, num_groups=num_groups)
|
||||||
|
decoders.append(decoder)
|
||||||
|
|
||||||
|
self.decoders = nn.ModuleList(decoders)
|
||||||
|
|
||||||
|
# in the last layer a 1×1 convolution reduces the number of output
|
||||||
|
# channels to the number of labels
|
||||||
|
self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
|
||||||
|
|
||||||
|
if is_segmentation:
|
||||||
|
# semantic segmentation problem
|
||||||
|
if final_sigmoid:
|
||||||
|
self.final_activation = nn.Sigmoid()
|
||||||
|
else:
|
||||||
|
self.final_activation = nn.Softmax(dim=1)
|
||||||
|
else:
|
||||||
|
# regression problem
|
||||||
|
self.final_activation = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
if self.embed_fn is not None:
|
||||||
|
x = self.embed_fn(x.permute(0, 2, 3, 4, 1))
|
||||||
|
x = x.permute(0, 4, 1, 2, 3)
|
||||||
|
|
||||||
|
# encoder part
|
||||||
|
encoders_features = []
|
||||||
|
for encoder in self.encoders:
|
||||||
|
x = encoder(x)
|
||||||
|
# reverse the encoder outputs to be aligned with the decoder
|
||||||
|
encoders_features.insert(0, x)
|
||||||
|
|
||||||
|
# remove the last encoder's output from the list
|
||||||
|
# !!remember: it's the 1st in the list
|
||||||
|
encoders_features = encoders_features[1:]
|
||||||
|
|
||||||
|
# decoder part
|
||||||
|
for decoder, encoder_features in zip(self.decoders, encoders_features):
|
||||||
|
# pass the output from the corresponding encoder and the output
|
||||||
|
# of the previous decoder
|
||||||
|
x = decoder(encoder_features, x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
# apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs
|
||||||
|
# logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
|
||||||
|
if self.testing and self.final_activation is not None:
|
||||||
|
x = self.final_activation(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UNet3D(Abstract3DUNet):
|
||||||
|
"""
|
||||||
|
3DUnet model from
|
||||||
|
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
|
||||||
|
<https://arxiv.org/pdf/1606.06650.pdf>`.
|
||||||
|
|
||||||
|
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
|
||||||
|
num_groups=8, num_levels=4, is_segmentation=True, **kwargs):
|
||||||
|
super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid,
|
||||||
|
basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order,
|
||||||
|
num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualUNet3D(Abstract3DUNet):
|
||||||
|
"""
|
||||||
|
Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
|
||||||
|
Uses ExtResNetBlock as a basic building block, summation joining instead
|
||||||
|
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
|
||||||
|
Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
|
||||||
|
num_groups=8, num_levels=5, is_segmentation=True, **kwargs):
|
||||||
|
super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels,
|
||||||
|
final_sigmoid=final_sigmoid,
|
||||||
|
basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order,
|
||||||
|
num_groups=num_groups, num_levels=num_levels,
|
||||||
|
is_segmentation=is_segmentation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(config):
|
||||||
|
def _model_class(class_name):
|
||||||
|
m = importlib.import_module('pytorch3dunet.unet3d.model')
|
||||||
|
clazz = getattr(m, class_name)
|
||||||
|
return clazz
|
||||||
|
|
||||||
|
assert 'model' in config, 'Could not find model configuration'
|
||||||
|
model_config = config['model']
|
||||||
|
model_class = _model_class(model_config['name'])
|
||||||
|
return model_class(**model_config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""
|
||||||
|
testing
|
||||||
|
"""
|
||||||
|
in_channels = 1
|
||||||
|
out_channels = 1
|
||||||
|
f_maps = 32
|
||||||
|
num_levels = 2
|
||||||
|
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr')
|
||||||
|
print(model)
|
||||||
|
print('number of parameters: ', sum(p.numel() for p in model.parameters()))
|
||||||
|
|
||||||
|
reso = 18
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
x = np.zeros((1, 1, reso, reso, reso))
|
||||||
|
x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan
|
||||||
|
x = torch.FloatTensor(x)
|
||||||
|
|
||||||
|
out = model(x)
|
||||||
|
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso)))
|
||||||
|
|
167
src/network/utils.py
Normal file
167
src/network/utils.py
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class Embedder:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.create_embedding_fn()
|
||||||
|
|
||||||
|
def create_embedding_fn(self):
|
||||||
|
embed_fns = []
|
||||||
|
d = self.kwargs['input_dims']
|
||||||
|
out_dim = 0
|
||||||
|
if self.kwargs['include_input']:
|
||||||
|
embed_fns.append(lambda x: x)
|
||||||
|
out_dim += d
|
||||||
|
|
||||||
|
max_freq = self.kwargs['max_freq_log2']
|
||||||
|
N_freqs = self.kwargs['num_freqs']
|
||||||
|
|
||||||
|
if self.kwargs['log_sampling']:
|
||||||
|
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
|
||||||
|
else:
|
||||||
|
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
|
||||||
|
|
||||||
|
for freq in freq_bands:
|
||||||
|
for p_fn in self.kwargs['periodic_fns']:
|
||||||
|
embed_fns.append(lambda x, p_fn=p_fn,
|
||||||
|
freq=freq: p_fn(x * freq))
|
||||||
|
out_dim += d
|
||||||
|
|
||||||
|
self.embed_fns = embed_fns
|
||||||
|
self.out_dim = out_dim
|
||||||
|
|
||||||
|
def embed(self, inputs):
|
||||||
|
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
||||||
|
|
||||||
|
def get_embedder(multires, d_in=3):
|
||||||
|
embed_kwargs = {
|
||||||
|
'include_input': True,
|
||||||
|
'input_dims': d_in,
|
||||||
|
'max_freq_log2': multires-1,
|
||||||
|
'num_freqs': multires,
|
||||||
|
'log_sampling': True,
|
||||||
|
'periodic_fns': [torch.sin, torch.cos],
|
||||||
|
}
|
||||||
|
|
||||||
|
embedder_obj = Embedder(**embed_kwargs)
|
||||||
|
def embed(x, eo=embedder_obj): return eo.embed(x)
|
||||||
|
return embed, embedder_obj.out_dim
|
||||||
|
|
||||||
|
def normalize_coordinate(p, plane='xz'):
|
||||||
|
''' Normalize coordinate to [0, 1] for unit cube experiments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
p (tensor): point
|
||||||
|
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
||||||
|
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
||||||
|
'''
|
||||||
|
if plane == 'xz':
|
||||||
|
xy = p[:, :, [0, 2]]
|
||||||
|
elif plane =='xy':
|
||||||
|
xy = p[:, :, [0, 1]]
|
||||||
|
else:
|
||||||
|
xy = p[:, :, [1, 2]]
|
||||||
|
|
||||||
|
xy_new = xy
|
||||||
|
# f there are outliers out of the range
|
||||||
|
if xy_new.max() >= 1:
|
||||||
|
xy_new[xy_new >= 1] = 1 - 10e-6
|
||||||
|
if xy_new.min() < 0:
|
||||||
|
xy_new[xy_new < 0] = 0.0
|
||||||
|
return xy_new
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_3d_coordinate(p):
|
||||||
|
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
||||||
|
'''
|
||||||
|
if p.max() >= 1:
|
||||||
|
p[p >= 1] = 1 - 10e-6
|
||||||
|
if p.min() < 0:
|
||||||
|
p[p < 0] = 0.0
|
||||||
|
return p
|
||||||
|
|
||||||
|
def coordinate2index(x, reso, coord_type='2d'):
|
||||||
|
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
||||||
|
Corresponds to our 3D model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tensor): coordinate
|
||||||
|
reso (int): defined resolution
|
||||||
|
coord_type (str): coordinate type
|
||||||
|
'''
|
||||||
|
x = (x * reso).long()
|
||||||
|
if coord_type == '2d': # plane
|
||||||
|
index = x[:, :, 0] + reso * x[:, :, 1]
|
||||||
|
elif coord_type == '3d': # grid
|
||||||
|
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
|
||||||
|
index = index[:, None, :]
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
class map2local(object):
|
||||||
|
''' Add new keys to the given input
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s (float): the defined voxel size
|
||||||
|
pos_encoding (str): method for the positional encoding, linear|sin_cos
|
||||||
|
'''
|
||||||
|
def __init__(self, s, pos_encoding='linear'):
|
||||||
|
super().__init__()
|
||||||
|
self.s = s
|
||||||
|
# self.pe = positional_encoding(basis_function=pos_encoding, local=True)
|
||||||
|
|
||||||
|
def __call__(self, p):
|
||||||
|
# p = torch.remainder(p, self.s) / self.s # always possitive
|
||||||
|
p = (p % self.s) / self.s
|
||||||
|
p[p < 0] = 0.0
|
||||||
|
# p = torch.fmod(p, self.s) / self.s # same sign as input p!
|
||||||
|
# p = self.pe(p)
|
||||||
|
return p
|
||||||
|
|
||||||
|
# Resnet Blocks
|
||||||
|
class ResnetBlockFC(nn.Module):
|
||||||
|
''' Fully connected ResNet Block class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size_in (int): input dimension
|
||||||
|
size_out (int): output dimension
|
||||||
|
size_h (int): hidden dimension
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, size_in, size_out=None, size_h=None, siren=False):
|
||||||
|
super().__init__()
|
||||||
|
# Attributes
|
||||||
|
if size_out is None:
|
||||||
|
size_out = size_in
|
||||||
|
|
||||||
|
if size_h is None:
|
||||||
|
size_h = min(size_in, size_out)
|
||||||
|
|
||||||
|
self.size_in = size_in
|
||||||
|
self.size_h = size_h
|
||||||
|
self.size_out = size_out
|
||||||
|
# Submodules
|
||||||
|
self.fc_0 = nn.Linear(size_in, size_h)
|
||||||
|
self.fc_1 = nn.Linear(size_h, size_out)
|
||||||
|
self.actvn = nn.ReLU()
|
||||||
|
|
||||||
|
if size_in == size_out:
|
||||||
|
self.shortcut = None
|
||||||
|
else:
|
||||||
|
self.shortcut = nn.Linear(size_in, size_out, bias=False)
|
||||||
|
# Initialization
|
||||||
|
nn.init.zeros_(self.fc_1.weight)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
net = self.fc_0(self.actvn(x))
|
||||||
|
dx = self.fc_1(self.actvn(net))
|
||||||
|
|
||||||
|
if self.shortcut is not None:
|
||||||
|
x_s = self.shortcut(x)
|
||||||
|
else:
|
||||||
|
x_s = x
|
||||||
|
|
||||||
|
return x_s + dx
|
349
src/optimization.py
Normal file
349
src/optimization.py
Normal file
|
@ -0,0 +1,349 @@
|
||||||
|
import time, os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import trimesh
|
||||||
|
|
||||||
|
from src.dpsr import DPSR
|
||||||
|
from src.model import PSR2Mesh
|
||||||
|
from src.utils import grid_interp, verts_on_largest_mesh,\
|
||||||
|
export_pointcloud, mc_from_psr, GaussianSmoothing
|
||||||
|
from src.visualize import visualize_points_mesh, visualize_psr_grid, \
|
||||||
|
visualize_mesh_phong, render_rgb
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
from torchvision.io import write_video
|
||||||
|
from pytorch3d.loss import chamfer_distance
|
||||||
|
import open3d as o3d
|
||||||
|
|
||||||
|
class Trainer(object):
|
||||||
|
'''
|
||||||
|
Args:
|
||||||
|
cfg : config file
|
||||||
|
optimizer : pytorch optimizer object
|
||||||
|
device : pytorch device
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, cfg, optimizer, device=None):
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.device = device
|
||||||
|
self.cfg = cfg
|
||||||
|
self.psr2mesh = PSR2Mesh.apply
|
||||||
|
self.data_type = cfg['data']['data_type']
|
||||||
|
|
||||||
|
# initialize DPSR
|
||||||
|
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
||||||
|
cfg['model']['grid_res'],
|
||||||
|
cfg['model']['grid_res']),
|
||||||
|
sig=cfg['model']['psr_sigma'])
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR
|
||||||
|
self.dpsr = self.dpsr.to(device)
|
||||||
|
|
||||||
|
def train_step(self, data, inputs, model, it):
|
||||||
|
''' Performs a training step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (dict) : data dictionary
|
||||||
|
inputs (torch.tensor) : input point clouds
|
||||||
|
model (nn.Module or None): a neural network or None
|
||||||
|
it (int) : the number of iterations
|
||||||
|
'''
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss, loss_each = self.compute_loss(inputs, data, model, it)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
return loss.item(), loss_each
|
||||||
|
|
||||||
|
def compute_loss(self, inputs, data, model, it=0):
|
||||||
|
''' Compute the loss.
|
||||||
|
Args:
|
||||||
|
data (dict) : data dictionary
|
||||||
|
inputs (torch.tensor) : input point clouds
|
||||||
|
model (nn.Module or None): a neural network or None
|
||||||
|
it (int) : the number of iterations
|
||||||
|
'''
|
||||||
|
|
||||||
|
device = self.device
|
||||||
|
res = self.cfg['model']['grid_res']
|
||||||
|
|
||||||
|
# source oriented point clouds to PSR grid
|
||||||
|
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||||
|
|
||||||
|
# build mesh
|
||||||
|
v, f, n = self.psr2mesh(psr_grid)
|
||||||
|
|
||||||
|
# the output is in the range of [0, 1), we make it to the real range [0, 1].
|
||||||
|
# This is a hack for our DPSR solver
|
||||||
|
v = v * res / (res-1)
|
||||||
|
|
||||||
|
points = points * 2. - 1.
|
||||||
|
v = v * 2. - 1. # within the range of (-1, 1)
|
||||||
|
|
||||||
|
loss = 0
|
||||||
|
loss_each = {}
|
||||||
|
# compute loss
|
||||||
|
if self.data_type == 'point':
|
||||||
|
if self.cfg['train']['w_chamfer'] > 0:
|
||||||
|
loss_ = self.cfg['train']['w_chamfer'] * \
|
||||||
|
self.compute_3d_loss(v, data)
|
||||||
|
loss_each['chamfer'] = loss_
|
||||||
|
loss += loss_
|
||||||
|
elif self.data_type == 'img':
|
||||||
|
loss, loss_each = self.compute_2d_loss(inputs, data, model)
|
||||||
|
|
||||||
|
return loss, loss_each
|
||||||
|
|
||||||
|
|
||||||
|
def pcl2psr(self, inputs):
|
||||||
|
''' Convert an oriented point cloud to PSR indicator grid
|
||||||
|
Args:
|
||||||
|
inputs (torch.tensor): input oriented point clouds
|
||||||
|
'''
|
||||||
|
|
||||||
|
points, normals = inputs[...,:3], inputs[...,3:]
|
||||||
|
if self.cfg['model']['apply_sigmoid']:
|
||||||
|
points = torch.sigmoid(points)
|
||||||
|
if self.cfg['model']['normal_normalize']:
|
||||||
|
normals = normals / normals.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# DPSR to get grid
|
||||||
|
psr_grid = self.dpsr(points, normals).unsqueeze(1)
|
||||||
|
psr_grid = torch.tanh(psr_grid)
|
||||||
|
|
||||||
|
return psr_grid, points, normals
|
||||||
|
|
||||||
|
def compute_3d_loss(self, v, data):
|
||||||
|
''' Compute the loss for point clouds.
|
||||||
|
Args:
|
||||||
|
v (torch.tensor) : mesh vertices
|
||||||
|
data (dict) : data dictionary
|
||||||
|
'''
|
||||||
|
|
||||||
|
pts_gt = data.get('target_points')
|
||||||
|
idx = np.random.randint(pts_gt.shape[1], size=self.cfg['train']['n_sup_point'])
|
||||||
|
if self.cfg['train']['subsample_vertex']:
|
||||||
|
#chamfer distance only on random sampled vertices
|
||||||
|
idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point'])
|
||||||
|
loss, _ = chamfer_distance(v[:, idx], pts_gt)
|
||||||
|
else:
|
||||||
|
loss, _ = chamfer_distance(v, pts_gt)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def compute_2d_loss(self, inputs, data, model):
|
||||||
|
''' Compute the 2D losses.
|
||||||
|
Args:
|
||||||
|
inputs (torch.tensor) : input source point clouds
|
||||||
|
data (dict) : data dictionary
|
||||||
|
model (nn.Module or None): neural network or None
|
||||||
|
'''
|
||||||
|
|
||||||
|
losses = {"color":
|
||||||
|
{"weight": self.cfg['train']['l_weight']['rgb'],
|
||||||
|
"values": []
|
||||||
|
},
|
||||||
|
"silhouette":
|
||||||
|
{"weight": self.cfg['train']['l_weight']['mask'],
|
||||||
|
"values": []},
|
||||||
|
}
|
||||||
|
loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses}
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
out = model(inputs, data)
|
||||||
|
|
||||||
|
if out['rgb'] is not None:
|
||||||
|
rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'],
|
||||||
|
-1, 3)[out['vis_mask']]
|
||||||
|
loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt,
|
||||||
|
out['rgb']) / out['rgb'].shape[0]
|
||||||
|
|
||||||
|
if out['mask'] is not None:
|
||||||
|
loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean()
|
||||||
|
|
||||||
|
# weighted sum of the losses
|
||||||
|
loss = torch.tensor(0.0, device=self.device)
|
||||||
|
for k, l in loss_all.items():
|
||||||
|
loss += l * losses[k]["weight"]
|
||||||
|
losses[k]["values"].append(l)
|
||||||
|
|
||||||
|
return loss, loss_all
|
||||||
|
|
||||||
|
def point_resampling(self, inputs):
|
||||||
|
''' Resample points
|
||||||
|
Args:
|
||||||
|
inputs (torch.tensor): oriented point clouds
|
||||||
|
'''
|
||||||
|
|
||||||
|
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||||
|
|
||||||
|
# shortcuts
|
||||||
|
n_grow = self.cfg['train']['n_grow_points']
|
||||||
|
|
||||||
|
# [hack] for points resampled from the mesh from marching cubes,
|
||||||
|
# we need to divide by s instead of (s-1), and the scale is correct.
|
||||||
|
verts, faces, _ = mc_from_psr(psr_grid, real_scale=False, zero_level=0)
|
||||||
|
|
||||||
|
# find the largest component
|
||||||
|
pts_mesh, faces_mesh = verts_on_largest_mesh(verts, faces)
|
||||||
|
|
||||||
|
# sample vertices only from the largest component, not from fragments
|
||||||
|
mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh)
|
||||||
|
pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True)
|
||||||
|
normals_i = mesh.face_normals[face_idx].astype('float32')
|
||||||
|
pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None]
|
||||||
|
n_mesh = torch.tensor(normals_i).to(self.device)[None]
|
||||||
|
|
||||||
|
points, normals = pts_mesh, n_mesh
|
||||||
|
print('{} total points are resampled'.format(points.shape[1]))
|
||||||
|
|
||||||
|
# update inputs
|
||||||
|
points = torch.log(points / (1 - points)) # inverse sigmoid
|
||||||
|
inputs = torch.cat([points, normals], dim=-1)
|
||||||
|
inputs.requires_grad = True
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def visualize(self, data, inputs, renderer, epoch, o3d_vis=None):
|
||||||
|
''' Visualization.
|
||||||
|
Args:
|
||||||
|
data (dict) : data dictionary
|
||||||
|
inputs (torch.tensor) : source point clouds
|
||||||
|
renderer (nn.Module or None): a neural network or None
|
||||||
|
epoch (int) : the number of iterations
|
||||||
|
o3d_vis (o3d.Visualizer) : open3d visualizer
|
||||||
|
'''
|
||||||
|
|
||||||
|
data_type = self.cfg['data']['data_type']
|
||||||
|
it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every']))
|
||||||
|
|
||||||
|
|
||||||
|
if (self.cfg['train']['exp_mesh']) \
|
||||||
|
| (self.cfg['train']['exp_pcl']) \
|
||||||
|
| (self.cfg['train']['o3d_show']):
|
||||||
|
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
v, f, n = mc_from_psr(psr_grid, pytorchify=True,
|
||||||
|
zero_level=self.cfg['data']['zero_level'], real_scale=True)
|
||||||
|
v, f, n = v[None], f[None], n[None]
|
||||||
|
|
||||||
|
v = v * 2. - 1. # change to the range of [-1, 1]
|
||||||
|
|
||||||
|
color_v = None
|
||||||
|
if data_type == 'img':
|
||||||
|
if self.cfg['train']['vis_vert_color'] & \
|
||||||
|
(self.cfg['train']['l_weight']['rgb'] != 0.):
|
||||||
|
color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy()
|
||||||
|
color_v[color_v<0], color_v[color_v>1] = 0., 1.
|
||||||
|
|
||||||
|
vv = v.detach().squeeze().cpu().numpy()
|
||||||
|
ff = f.detach().squeeze().cpu().numpy()
|
||||||
|
points = points * 2 - 1
|
||||||
|
visualize_points_mesh(o3d_vis, points, normals,
|
||||||
|
vv, ff, self.cfg, it, epoch, color_v=color_v)
|
||||||
|
|
||||||
|
else:
|
||||||
|
v, f, n = inputs
|
||||||
|
|
||||||
|
|
||||||
|
if (data_type == 'img') & (self.cfg['train']['vis_rendering']):
|
||||||
|
pred_imgs = []
|
||||||
|
pred_masks = []
|
||||||
|
n_views = len(data['poses'])
|
||||||
|
# idx_list = trange(n_views)
|
||||||
|
idx_list = [13, 24, 27, 48]
|
||||||
|
|
||||||
|
#!
|
||||||
|
model = renderer.eval()
|
||||||
|
for idx in idx_list:
|
||||||
|
pose = data['poses'][idx]
|
||||||
|
rgb = data['rgbs'][idx]
|
||||||
|
mask_gt = data['masks'][idx]
|
||||||
|
img_size = rgb.shape[0] if rgb.shape[0]== rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
|
||||||
|
ray = None
|
||||||
|
if 'rays' in data.keys():
|
||||||
|
ray = data['rays'][idx]
|
||||||
|
if self.cfg['train']['l_weight']['rgb'] != 0.:
|
||||||
|
fea_grid = None
|
||||||
|
if model.unet3d is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1)
|
||||||
|
if model.encoder is not None:
|
||||||
|
pp = torch.cat([(points+1)/2, normals], dim=-1)
|
||||||
|
fea_grid = model.encoder(pp,
|
||||||
|
normalize=False).permute(0, 2, 3, 4, 1)
|
||||||
|
|
||||||
|
pred, visible_mask = render_rgb(v, f, n, pose,
|
||||||
|
model.rendering_network.eval(),
|
||||||
|
img_size, ray=ray, fea_grid=fea_grid)
|
||||||
|
img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3])
|
||||||
|
img_pred[visible_mask] = pred.detach().cpu()
|
||||||
|
|
||||||
|
img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3)
|
||||||
|
img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1.
|
||||||
|
filename=os.path.join(self.cfg['train']['dir_rendering'],
|
||||||
|
'rendering_{}_{:d}.png'.format(it, idx))
|
||||||
|
save_image(img_pred.permute(2, 0, 1), filename)
|
||||||
|
pred_imgs.append(img_pred)
|
||||||
|
|
||||||
|
#! Mesh rendering using Phong shading model
|
||||||
|
filename=os.path.join(self.cfg['train']['dir_rendering'],
|
||||||
|
'mesh_{}_{:d}.png'.format(it, idx))
|
||||||
|
visualize_mesh_phong(v, f, n, pose, img_size, name=filename)
|
||||||
|
|
||||||
|
if len(pred_imgs) >= 1:
|
||||||
|
pred_imgs = torch.stack(pred_imgs, dim=0)
|
||||||
|
save_image(pred_imgs.permute(0, 3, 1, 2),
|
||||||
|
os.path.join(self.cfg['train']['dir_rendering'],
|
||||||
|
'{}.png'.format(it)), nrow=4)
|
||||||
|
if self.cfg['train']['save_video']:
|
||||||
|
write_video(os.path.join(self.cfg['train']['dir_rendering'],
|
||||||
|
'{}.mp4'.format(it)),
|
||||||
|
(pred_imgs*255.).type(torch.uint8), fps=24)
|
||||||
|
|
||||||
|
def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None):
|
||||||
|
''' Save meshes and point clouds.
|
||||||
|
Args:
|
||||||
|
inputs (torch.tensor) : source point clouds
|
||||||
|
epoch (int) : the number of iterations
|
||||||
|
center (numpy.array) : center of the shape
|
||||||
|
scale (numpy.array) : scale of the shape
|
||||||
|
'''
|
||||||
|
|
||||||
|
exp_pcl = self.cfg['train']['exp_pcl']
|
||||||
|
exp_mesh = self.cfg['train']['exp_mesh']
|
||||||
|
|
||||||
|
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||||
|
|
||||||
|
if exp_pcl:
|
||||||
|
dir_pcl = self.cfg['train']['dir_pcl']
|
||||||
|
p = points.squeeze(0).detach().cpu().numpy()
|
||||||
|
p = p * 2 - 1
|
||||||
|
n = normals.squeeze(0).detach().cpu().numpy()
|
||||||
|
if scale is not None:
|
||||||
|
p *= scale
|
||||||
|
if center is not None:
|
||||||
|
p += center
|
||||||
|
export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n)
|
||||||
|
if exp_mesh:
|
||||||
|
dir_mesh = self.cfg['train']['dir_mesh']
|
||||||
|
with torch.no_grad():
|
||||||
|
v, f, _ = mc_from_psr(psr_grid,
|
||||||
|
zero_level=self.cfg['data']['zero_level'], real_scale=True)
|
||||||
|
v = v * 2 - 1
|
||||||
|
if scale is not None:
|
||||||
|
v *= scale
|
||||||
|
if center is not None:
|
||||||
|
v += center
|
||||||
|
mesh = o3d.geometry.TriangleMesh()
|
||||||
|
mesh.vertices = o3d.utility.Vector3dVector(v)
|
||||||
|
mesh.triangles = o3d.utility.Vector3iVector(f)
|
||||||
|
outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch))
|
||||||
|
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
|
||||||
|
|
||||||
|
if self.cfg['train']['vis_psr']:
|
||||||
|
dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all'
|
||||||
|
visualize_psr_grid(psr_grid, out_dir=dir_psr_vis)
|
207
src/training.py
Normal file
207
src/training.py
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from collections import defaultdict
|
||||||
|
import trimesh
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from src.dpsr import DPSR
|
||||||
|
from src.utils import grid_interp, export_pointcloud, export_mesh, \
|
||||||
|
mc_from_psr, scale2onet, GaussianSmoothing
|
||||||
|
from pytorch3d.ops.knn import knn_gather, knn_points
|
||||||
|
from pytorch3d.loss import chamfer_distance
|
||||||
|
from pdb import set_trace as st
|
||||||
|
|
||||||
|
class Trainer(object):
|
||||||
|
'''
|
||||||
|
Args:
|
||||||
|
model (nn.Module): our defined model
|
||||||
|
optimizer (optimizer): pytorch optimizer object
|
||||||
|
device (device): pytorch device
|
||||||
|
input_type (str): input type
|
||||||
|
vis_dir (str): visualization directory
|
||||||
|
'''
|
||||||
|
def __init__(self, cfg, optimizer, device=None):
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.device = device
|
||||||
|
self.cfg = cfg
|
||||||
|
if self.cfg['train']['w_raw'] != 0:
|
||||||
|
from src.model import PSR2Mesh
|
||||||
|
self.psr2mesh = PSR2Mesh.apply
|
||||||
|
|
||||||
|
# initialize DPSR
|
||||||
|
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
||||||
|
cfg['model']['grid_res'],
|
||||||
|
cfg['model']['grid_res']),
|
||||||
|
sig=cfg['model']['psr_sigma'])
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR
|
||||||
|
self.dpsr = self.dpsr.to(device)
|
||||||
|
|
||||||
|
if cfg['train']['gauss_weight']>0.:
|
||||||
|
self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device)
|
||||||
|
|
||||||
|
def train_step(self, inputs, data, model):
|
||||||
|
''' Performs a training step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (dict): data dictionary
|
||||||
|
'''
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
p = data.get('inputs').to(self.device)
|
||||||
|
|
||||||
|
out = model(p)
|
||||||
|
|
||||||
|
points, normals = out
|
||||||
|
|
||||||
|
loss = 0
|
||||||
|
loss_each = {}
|
||||||
|
if self.cfg['train']['w_psr'] != 0:
|
||||||
|
psr_gt = data.get('gt_psr').to(self.device)
|
||||||
|
if self.cfg['model']['psr_tanh']:
|
||||||
|
psr_gt = torch.tanh(psr_gt)
|
||||||
|
|
||||||
|
psr_grid = self.dpsr(points, normals)
|
||||||
|
if self.cfg['model']['psr_tanh']:
|
||||||
|
psr_grid = torch.tanh(psr_grid)
|
||||||
|
|
||||||
|
# apply a rescaling weight based on GT SDF values
|
||||||
|
if self.cfg['train']['gauss_weight']>0:
|
||||||
|
gauss_sigma = self.cfg['train']['gauss_weight']
|
||||||
|
# set up the weighting for loss, higher weights
|
||||||
|
# for points near to the surface
|
||||||
|
psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1)
|
||||||
|
delta_x = delta_y = delta_z = 1
|
||||||
|
grad_x = (psr_gt_pad[:, 2:, :, :] - psr_gt_pad[:, :-2, :, :]) / 2 / delta_x
|
||||||
|
grad_y = (psr_gt_pad[:, :, 2:, :] - psr_gt_pad[:, :, :-2, :]) / 2 / delta_y
|
||||||
|
grad_z = (psr_gt_pad[:, :, :, 2:] - psr_gt_pad[:, :, :, :-2]) / 2 / delta_z
|
||||||
|
grad_x = grad_x[:, :, 1:-1, 1:-1]
|
||||||
|
grad_y = grad_y[:, 1:-1, :, 1:-1]
|
||||||
|
grad_z = grad_z[:, 1:-1, 1:-1, :]
|
||||||
|
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1)
|
||||||
|
psr_grad_norm = psr_grad.norm(dim=-1)[:, None]
|
||||||
|
w = torch.nn.ReplicationPad3d(3)(psr_grad_norm)
|
||||||
|
w = 2*self.gauss_smooth(w).squeeze(1)
|
||||||
|
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt)
|
||||||
|
else:
|
||||||
|
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt)
|
||||||
|
|
||||||
|
loss += loss_each['psr']
|
||||||
|
|
||||||
|
# regularization on the input point positions via chamfer distance
|
||||||
|
if self.cfg['train']['w_reg_point'] != 0.:
|
||||||
|
points_gt = data.get('gt_points').to(self.device)
|
||||||
|
loss_reg, loss_norm = chamfer_distance(points, points_gt)
|
||||||
|
|
||||||
|
loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg
|
||||||
|
loss += loss_each['reg']
|
||||||
|
|
||||||
|
if self.cfg['train']['w_normals'] != 0.:
|
||||||
|
points_gt = data.get('gt_points').to(self.device)
|
||||||
|
normals_gt = data.get('gt_points.normals').to(self.device)
|
||||||
|
x_nn = knn_points(points, points_gt, K=1)
|
||||||
|
x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :]
|
||||||
|
|
||||||
|
cham_norm_x = F.l1_loss(normals, x_normals_near)
|
||||||
|
loss_norm = cham_norm_x
|
||||||
|
|
||||||
|
loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm
|
||||||
|
loss += loss_each['normals']
|
||||||
|
|
||||||
|
if self.cfg['train']['w_raw'] != 0:
|
||||||
|
res = self.cfg['model']['grid_res']
|
||||||
|
# DPSR to get grid
|
||||||
|
psr_grid = self.dpsr(points, normals)
|
||||||
|
if self.cfg['model']['psr_tanh']:
|
||||||
|
psr_grid = torch.tanh(psr_grid)
|
||||||
|
|
||||||
|
v, f, n = self.psr2mesh(psr_grid)
|
||||||
|
|
||||||
|
pts_gt = data.get('gt_points').to(self.device)
|
||||||
|
|
||||||
|
loss, _ = chamfer_distance(v, pts_gt)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
return loss.item(), loss_each
|
||||||
|
|
||||||
|
def save(self, model, data, epoch, id):
|
||||||
|
|
||||||
|
p = data.get('inputs').to(self.device)
|
||||||
|
|
||||||
|
exp_pcl = self.cfg['train']['exp_pcl']
|
||||||
|
exp_mesh = self.cfg['train']['exp_mesh']
|
||||||
|
exp_gt = self.cfg['generation']['exp_gt']
|
||||||
|
exp_input = self.cfg['generation']['exp_input']
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
points, normals = model(p)
|
||||||
|
|
||||||
|
if exp_gt:
|
||||||
|
points_gt = data.get('gt_points').to(self.device)
|
||||||
|
normals_gt = data.get('gt_points.normals').to(self.device)
|
||||||
|
|
||||||
|
if exp_pcl:
|
||||||
|
dir_pcl = self.cfg['train']['dir_pcl']
|
||||||
|
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals)
|
||||||
|
if exp_gt:
|
||||||
|
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt)
|
||||||
|
if exp_input:
|
||||||
|
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p))
|
||||||
|
|
||||||
|
if exp_mesh:
|
||||||
|
dir_mesh = self.cfg['train']['dir_mesh']
|
||||||
|
psr_grid = self.dpsr(points, normals)
|
||||||
|
# psr_grid = torch.tanh(psr_grid)
|
||||||
|
with torch.no_grad():
|
||||||
|
v, f, _ = mc_from_psr(psr_grid,
|
||||||
|
zero_level=self.cfg['data']['zero_level'])
|
||||||
|
outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id))
|
||||||
|
export_mesh(outdir_mesh, scale2onet(v), f)
|
||||||
|
if exp_gt:
|
||||||
|
psr_gt = self.dpsr(points_gt, normals_gt)
|
||||||
|
with torch.no_grad():
|
||||||
|
v, f, _ = mc_from_psr(psr_gt,
|
||||||
|
zero_level=self.cfg['data']['zero_level'])
|
||||||
|
export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f)
|
||||||
|
|
||||||
|
def evaluate(self, val_loader, model):
|
||||||
|
''' Performs an evaluation.
|
||||||
|
Args:
|
||||||
|
val_loader (dataloader): pytorch dataloader
|
||||||
|
'''
|
||||||
|
eval_list = defaultdict(list)
|
||||||
|
|
||||||
|
for data in tqdm(val_loader):
|
||||||
|
eval_step_dict = self.eval_step(data, model)
|
||||||
|
|
||||||
|
for k, v in eval_step_dict.items():
|
||||||
|
eval_list[k].append(v)
|
||||||
|
|
||||||
|
eval_dict = {k: np.mean(v) for k, v in eval_list.items()}
|
||||||
|
return eval_dict
|
||||||
|
|
||||||
|
def eval_step(self, data, model):
|
||||||
|
''' Performs an evaluation step.
|
||||||
|
Args:
|
||||||
|
data (dict): data dictionary
|
||||||
|
'''
|
||||||
|
model.eval()
|
||||||
|
eval_dict = {}
|
||||||
|
|
||||||
|
p = data.get('inputs').to(self.device)
|
||||||
|
psr_gt = data.get('gt_psr').to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# forward pass
|
||||||
|
points, normals = model(p)
|
||||||
|
# DPSR to get predicted psr grid
|
||||||
|
psr_grid = self.dpsr(points, normals)
|
||||||
|
|
||||||
|
eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item()
|
||||||
|
eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item()
|
||||||
|
|
||||||
|
return eval_dict
|
645
src/utils.py
Normal file
645
src/utils.py
Normal file
|
@ -0,0 +1,645 @@
|
||||||
|
import torch
|
||||||
|
import io, os, logging, urllib
|
||||||
|
import yaml
|
||||||
|
import trimesh
|
||||||
|
import imageio
|
||||||
|
import numbers
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from collections import OrderedDict
|
||||||
|
from plyfile import PlyData
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.utils import model_zoo
|
||||||
|
from skimage import measure, img_as_float32
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
|
||||||
|
from igl import adjacency_matrix, connected_components
|
||||||
|
import open3d as o3d
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# Below are functions for DPSR
|
||||||
|
|
||||||
|
def fftfreqs(res, dtype=torch.float32, exact=True):
|
||||||
|
"""
|
||||||
|
Helper function to return frequency tensors
|
||||||
|
:param res: n_dims int tuple of number of frequency modes
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
n_dims = len(res)
|
||||||
|
freqs = []
|
||||||
|
for dim in range(n_dims - 1):
|
||||||
|
r_ = res[dim]
|
||||||
|
freq = np.fft.fftfreq(r_, d=1/r_)
|
||||||
|
freqs.append(torch.tensor(freq, dtype=dtype))
|
||||||
|
r_ = res[-1]
|
||||||
|
if exact:
|
||||||
|
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
|
||||||
|
else:
|
||||||
|
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
|
||||||
|
omega = torch.meshgrid(freqs)
|
||||||
|
omega = list(omega)
|
||||||
|
omega = torch.stack(omega, dim=-1)
|
||||||
|
|
||||||
|
return omega
|
||||||
|
|
||||||
|
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
|
||||||
|
"""
|
||||||
|
multiply tensor x by i ** deg
|
||||||
|
"""
|
||||||
|
deg %= 4
|
||||||
|
if deg == 0:
|
||||||
|
res = x
|
||||||
|
elif deg == 1:
|
||||||
|
res = x[..., [1, 0]]
|
||||||
|
res[..., 0] = -res[..., 0]
|
||||||
|
elif deg == 2:
|
||||||
|
res = -x
|
||||||
|
elif deg == 3:
|
||||||
|
res = x[..., [1, 0]]
|
||||||
|
res[..., 1] = -res[..., 1]
|
||||||
|
return res
|
||||||
|
|
||||||
|
def spec_gaussian_filter(res, sig):
|
||||||
|
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
|
||||||
|
dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
|
||||||
|
filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
filter_.requires_grad = False
|
||||||
|
|
||||||
|
return filter_
|
||||||
|
|
||||||
|
def grid_interp(grid, pts, batched=True):
|
||||||
|
"""
|
||||||
|
:param grid: tensor of shape (batch, *size, in_features)
|
||||||
|
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
|
||||||
|
:return values at query points
|
||||||
|
"""
|
||||||
|
if not batched:
|
||||||
|
grid = grid.unsqueeze(0)
|
||||||
|
pts = pts.unsqueeze(0)
|
||||||
|
dim = pts.shape[-1]
|
||||||
|
bs = grid.shape[0]
|
||||||
|
size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
|
||||||
|
cubesize = 1.0 / size
|
||||||
|
|
||||||
|
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
||||||
|
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
||||||
|
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
||||||
|
tmp = torch.tensor([0,1],dtype=torch.long)
|
||||||
|
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
||||||
|
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
||||||
|
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
||||||
|
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||||
|
ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
||||||
|
# latent code on neighbor nodes
|
||||||
|
if dim == 2:
|
||||||
|
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
|
||||||
|
else:
|
||||||
|
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features)
|
||||||
|
|
||||||
|
# weights of neighboring nodes
|
||||||
|
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
||||||
|
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
||||||
|
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
||||||
|
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
||||||
|
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
||||||
|
pos_ = pos_.type(pts.dtype)
|
||||||
|
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
||||||
|
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
||||||
|
query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features)
|
||||||
|
if not batched:
|
||||||
|
query_values = query_values.squeeze(0)
|
||||||
|
|
||||||
|
return query_values
|
||||||
|
|
||||||
|
def scatter_to_grid(inds, vals, size):
|
||||||
|
"""
|
||||||
|
Scatter update values into empty tensor of size size.
|
||||||
|
:param inds: (#values, dims)
|
||||||
|
:param vals: (#values)
|
||||||
|
:param size: tuple for size. len(size)=dims
|
||||||
|
"""
|
||||||
|
dims = inds.shape[1]
|
||||||
|
assert(inds.shape[0] == vals.shape[0])
|
||||||
|
assert(len(size) == dims)
|
||||||
|
dev = vals.device
|
||||||
|
# result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
|
||||||
|
# # flatten inds
|
||||||
|
result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
|
||||||
|
# flatten inds
|
||||||
|
fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
|
||||||
|
fac = torch.tensor(fac, device=dev).type(inds.dtype)
|
||||||
|
inds_fold = torch.sum(inds*fac, dim=-1) # [#values,]
|
||||||
|
result.scatter_add_(0, inds_fold, vals)
|
||||||
|
result = result.view(*size)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def point_rasterize(pts, vals, size):
|
||||||
|
"""
|
||||||
|
:param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
|
||||||
|
:param vals: point values, tensor of shape (batch, num_points, features)
|
||||||
|
:param size: len(size)=dim tuple for grid size
|
||||||
|
:return rasterized values (batch, features, res0, res1, res2)
|
||||||
|
"""
|
||||||
|
dim = pts.shape[-1]
|
||||||
|
assert(pts.shape[:2] == vals.shape[:2])
|
||||||
|
assert(pts.shape[2] == dim)
|
||||||
|
size_list = list(size)
|
||||||
|
size = torch.tensor(size).to(pts.device).float()
|
||||||
|
cubesize = 1.0 / size
|
||||||
|
bs = pts.shape[0]
|
||||||
|
nf = vals.shape[-1]
|
||||||
|
npts = pts.shape[1]
|
||||||
|
dev = pts.device
|
||||||
|
|
||||||
|
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
||||||
|
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
||||||
|
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
||||||
|
tmp = torch.tensor([0,1],dtype=torch.long)
|
||||||
|
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
||||||
|
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
||||||
|
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
||||||
|
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||||
|
# ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
||||||
|
ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
||||||
|
|
||||||
|
# weights of neighboring nodes
|
||||||
|
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
||||||
|
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
||||||
|
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
||||||
|
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
||||||
|
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
||||||
|
pos_ = pos_.type(pts.dtype)
|
||||||
|
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
||||||
|
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
||||||
|
|
||||||
|
ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1)
|
||||||
|
ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim)
|
||||||
|
ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
|
||||||
|
# ind_f = torch.arange(nf).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
|
||||||
|
|
||||||
|
ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
|
||||||
|
ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
|
||||||
|
ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
|
||||||
|
inds = torch.cat([ind_b, ind_f, ind_n], dim=-1) # (batch, num_points, 2**dim, nf, 1+1+dim)
|
||||||
|
|
||||||
|
# weighted values
|
||||||
|
vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
|
||||||
|
|
||||||
|
inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
|
||||||
|
vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
|
||||||
|
tensor_size = [bs, nf] + size_list
|
||||||
|
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list)
|
||||||
|
|
||||||
|
return raster # [batch, nf, res, res, res]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# Below are the utilization functions in general
|
||||||
|
|
||||||
|
class AverageMeter(object):
|
||||||
|
"""Computes and stores the average and current value"""
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.val = 0
|
||||||
|
self.n = 0
|
||||||
|
self.avg = 0
|
||||||
|
self.sum = 0
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
def update(self, val, n=1):
|
||||||
|
self.val = val
|
||||||
|
self.n = n
|
||||||
|
self.sum += val * n
|
||||||
|
self.count += n
|
||||||
|
self.avg = self.sum / self.count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valcavg(self):
|
||||||
|
return self.val.sum().item() / (self.n != 0).sum().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avgcavg(self):
|
||||||
|
return self.avg.sum().item() / (self.count != 0).sum().item()
|
||||||
|
|
||||||
|
def load_model_manual(state_dict, model):
|
||||||
|
new_state_dict = OrderedDict()
|
||||||
|
is_model_parallel = isinstance(model, torch.nn.DataParallel)
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if k.startswith('module.') != is_model_parallel:
|
||||||
|
if k.startswith('module.'):
|
||||||
|
# remove module
|
||||||
|
k = k[7:]
|
||||||
|
else:
|
||||||
|
# add module
|
||||||
|
k = 'module.' + k
|
||||||
|
|
||||||
|
new_state_dict[k]=v
|
||||||
|
|
||||||
|
model.load_state_dict(new_state_dict)
|
||||||
|
|
||||||
|
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
|
||||||
|
'''
|
||||||
|
Run marching cubes from PSR grid
|
||||||
|
'''
|
||||||
|
batch_size = psr_grid.shape[0]
|
||||||
|
s = psr_grid.shape[-1] # size of psr_grid
|
||||||
|
psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
|
||||||
|
|
||||||
|
if batch_size>1:
|
||||||
|
verts, faces, normals = [], [], []
|
||||||
|
for i in range(batch_size):
|
||||||
|
verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
|
||||||
|
verts.append(verts_cur)
|
||||||
|
faces.append(faces_cur)
|
||||||
|
normals.append(normals_cur)
|
||||||
|
verts = np.stack(verts, axis = 0)
|
||||||
|
faces = np.stack(faces, axis = 0)
|
||||||
|
normals = np.stack(normals, axis = 0)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
|
||||||
|
except:
|
||||||
|
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
|
||||||
|
if real_scale:
|
||||||
|
verts = verts / (s-1) # scale to range [0, 1]
|
||||||
|
else:
|
||||||
|
verts = verts / s # scale to range [0, 1)
|
||||||
|
|
||||||
|
if pytorchify:
|
||||||
|
device = psr_grid.device
|
||||||
|
verts = torch.Tensor(np.ascontiguousarray(verts)).to(device)
|
||||||
|
faces = torch.Tensor(np.ascontiguousarray(faces)).to(device)
|
||||||
|
normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)
|
||||||
|
|
||||||
|
return verts, faces, normals
|
||||||
|
|
||||||
|
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
|
||||||
|
verts = verts.squeeze()
|
||||||
|
faces = faces.squeeze()
|
||||||
|
pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size)
|
||||||
|
if mask_gt is not None:
|
||||||
|
#! only evaluate within the intersection
|
||||||
|
mask = mask & mask_gt
|
||||||
|
# find 3D points intesected on the mesh
|
||||||
|
if True:
|
||||||
|
w_masked = w[mask]
|
||||||
|
f_p = faces[pix_to_face[mask]].long() # cooresponding faces for each pixel
|
||||||
|
# corresponding vertices for p_closest
|
||||||
|
v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
|
||||||
|
|
||||||
|
# calculate the intersection point of each pixel and the mesh
|
||||||
|
p_inters = w_masked[..., 0, None] * v_a + \
|
||||||
|
w_masked[..., 1, None] * v_b + \
|
||||||
|
w_masked[..., 2, None] * v_c
|
||||||
|
else:
|
||||||
|
# backproject ndc to world coordinates using z-buffer
|
||||||
|
W, H = img_size[1], img_size[0]
|
||||||
|
xy = uv.to(mask.device)[mask]
|
||||||
|
x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
|
||||||
|
y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
|
||||||
|
z = zbuf.squeeze().reshape(H * W)[mask]
|
||||||
|
xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
|
||||||
|
|
||||||
|
p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
|
||||||
|
|
||||||
|
# if there are outlier points, we should remove it
|
||||||
|
if (p_inters.max()>1) | (p_inters.min()<-1):
|
||||||
|
mask_bound = (p_inters>=-1) & (p_inters<=1)
|
||||||
|
mask_bound = (mask_bound.sum(dim=-1)==3)
|
||||||
|
mask[mask==True] = mask_bound
|
||||||
|
p_inters = p_inters[mask_bound]
|
||||||
|
print('!!!!!find outlier!')
|
||||||
|
|
||||||
|
return p_inters, mask, f_p, w_masked
|
||||||
|
|
||||||
|
def mesh_rasterization(verts, faces, pose, img_size):
|
||||||
|
'''
|
||||||
|
Use PyTorch3D to rasterize the mesh given a camera
|
||||||
|
'''
|
||||||
|
transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
|
||||||
|
if isinstance(pose, PerspectiveCameras):
|
||||||
|
transformed_v[..., 2] = 1/transformed_v[..., 2]
|
||||||
|
# find p_closest on mesh of each pixel via rasterization
|
||||||
|
transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
|
||||||
|
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
||||||
|
transformed_mesh,
|
||||||
|
image_size=img_size,
|
||||||
|
blur_radius=0,
|
||||||
|
faces_per_pixel=1,
|
||||||
|
perspective_correct=False
|
||||||
|
)
|
||||||
|
pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
|
||||||
|
mask = pix_to_face.clone() != -1
|
||||||
|
mask = mask.squeeze()
|
||||||
|
pix_to_face = pix_to_face.squeeze()
|
||||||
|
w = bary_coords.reshape(-1, 3)
|
||||||
|
|
||||||
|
return pix_to_face, w, mask
|
||||||
|
|
||||||
|
def verts_on_largest_mesh(verts, faces):
|
||||||
|
'''
|
||||||
|
verts: Numpy array or Torch.Tensor (N, 3)
|
||||||
|
faces: Numpy array (N, 3)
|
||||||
|
'''
|
||||||
|
if torch.is_tensor(faces):
|
||||||
|
verts = verts.squeeze().detach().cpu().numpy()
|
||||||
|
faces = faces.squeeze().int().detach().cpu().numpy()
|
||||||
|
|
||||||
|
A = adjacency_matrix(faces)
|
||||||
|
num, conn_idx, conn_size = connected_components(A)
|
||||||
|
if num == 0:
|
||||||
|
v_large, f_large = verts, faces
|
||||||
|
else:
|
||||||
|
max_idx = conn_size.argmax() # find the index of the largest component
|
||||||
|
v_large = verts[conn_idx==max_idx] # keep points on the largest component
|
||||||
|
|
||||||
|
if True:
|
||||||
|
mesh_largest = trimesh.Trimesh(verts, faces)
|
||||||
|
connected_comp = mesh_largest.split(only_watertight=False)
|
||||||
|
mesh_largest = connected_comp[max_idx]
|
||||||
|
v_large, f_large = mesh_largest.vertices, mesh_largest.faces
|
||||||
|
v_large = v_large.astype(np.float32)
|
||||||
|
return v_large, f_large
|
||||||
|
|
||||||
|
def load_pointcloud(in_file):
|
||||||
|
plydata = PlyData.read(in_file)
|
||||||
|
vertices = np.stack([
|
||||||
|
plydata['vertex']['x'],
|
||||||
|
plydata['vertex']['y'],
|
||||||
|
plydata['vertex']['z']
|
||||||
|
], axis=1)
|
||||||
|
return vertices
|
||||||
|
|
||||||
|
# General config
|
||||||
|
def load_config(path, default_path=None):
|
||||||
|
''' Loads config file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): path to config file
|
||||||
|
default_path (bool): whether to use default path
|
||||||
|
'''
|
||||||
|
# Load configuration from file itself
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
cfg_special = yaml.load(f, Loader=yaml.Loader)
|
||||||
|
|
||||||
|
# Check if we should inherit from a config
|
||||||
|
inherit_from = cfg_special.get('inherit_from')
|
||||||
|
|
||||||
|
# If yes, load this config first as default
|
||||||
|
# If no, use the default_path
|
||||||
|
if inherit_from is not None:
|
||||||
|
cfg = load_config(inherit_from, default_path)
|
||||||
|
elif default_path is not None:
|
||||||
|
with open(default_path, 'r') as f:
|
||||||
|
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||||
|
else:
|
||||||
|
cfg = dict()
|
||||||
|
|
||||||
|
# Include main configuration
|
||||||
|
update_recursive(cfg, cfg_special)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
def update_config(config, unknown):
|
||||||
|
# update config given args
|
||||||
|
for idx,arg in enumerate(unknown):
|
||||||
|
if arg.startswith("--"):
|
||||||
|
keys = arg.replace("--","").split(':')
|
||||||
|
assert(len(keys)==2)
|
||||||
|
k1, k2 = keys
|
||||||
|
argtype = type(config[k1][k2])
|
||||||
|
if argtype == bool:
|
||||||
|
v = unknown[idx+1].lower() == 'true'
|
||||||
|
else:
|
||||||
|
if config[k1][k2] is not None:
|
||||||
|
v = type(config[k1][k2])(unknown[idx+1])
|
||||||
|
else:
|
||||||
|
v = unknown[idx+1]
|
||||||
|
print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
|
||||||
|
config[k1][k2] = v
|
||||||
|
return config
|
||||||
|
|
||||||
|
def initialize_logger(cfg):
|
||||||
|
out_dir = cfg['train']['out_dir']
|
||||||
|
if not out_dir:
|
||||||
|
os.makedirs(out_dir)
|
||||||
|
|
||||||
|
cfg['train']['dir_model'] = os.path.join(out_dir, 'model')
|
||||||
|
os.makedirs(cfg['train']['dir_model'], exist_ok=True)
|
||||||
|
|
||||||
|
if cfg['train']['exp_mesh']:
|
||||||
|
cfg['train']['dir_mesh'] = os.path.join(out_dir, 'vis/mesh')
|
||||||
|
os.makedirs(cfg['train']['dir_mesh'], exist_ok=True)
|
||||||
|
if cfg['train']['exp_pcl']:
|
||||||
|
cfg['train']['dir_pcl'] = os.path.join(out_dir, 'vis/pointcloud')
|
||||||
|
os.makedirs(cfg['train']['dir_pcl'], exist_ok=True)
|
||||||
|
if cfg['train']['vis_rendering']:
|
||||||
|
cfg['train']['dir_rendering'] = os.path.join(out_dir, 'vis/rendering')
|
||||||
|
os.makedirs(cfg['train']['dir_rendering'], exist_ok=True)
|
||||||
|
if cfg['train']['o3d_show']:
|
||||||
|
cfg['train']['dir_o3d'] = os.path.join(out_dir, 'vis/o3d')
|
||||||
|
os.makedirs(cfg['train']['dir_o3d'], exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("train")
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logger.handlers = []
|
||||||
|
# ch = logging.StreamHandler()
|
||||||
|
# logger.addHandler(ch)
|
||||||
|
fh = logging.FileHandler(os.path.join(cfg['train']['out_dir'], "log.txt"))
|
||||||
|
logger.addHandler(fh)
|
||||||
|
logger.info('Outout dir: %s', out_dir)
|
||||||
|
return logger
|
||||||
|
|
||||||
|
def update_recursive(dict1, dict2):
|
||||||
|
''' Update two config dictionaries recursively.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dict1 (dict): first dictionary to be updated
|
||||||
|
dict2 (dict): second dictionary which entries should be used
|
||||||
|
|
||||||
|
'''
|
||||||
|
for k, v in dict2.items():
|
||||||
|
if k not in dict1:
|
||||||
|
dict1[k] = dict()
|
||||||
|
if isinstance(v, dict):
|
||||||
|
update_recursive(dict1[k], v)
|
||||||
|
else:
|
||||||
|
dict1[k] = v
|
||||||
|
|
||||||
|
def export_pointcloud(name, points, normals=None):
|
||||||
|
if len(points.shape) > 2:
|
||||||
|
points = points[0]
|
||||||
|
if normals is not None:
|
||||||
|
normals = normals[0]
|
||||||
|
if isinstance(points, torch.Tensor):
|
||||||
|
points = points.detach().cpu().numpy()
|
||||||
|
if normals is not None:
|
||||||
|
normals = normals.detach().cpu().numpy()
|
||||||
|
pcd = o3d.geometry.PointCloud()
|
||||||
|
pcd.points = o3d.utility.Vector3dVector(points)
|
||||||
|
if normals is not None:
|
||||||
|
pcd.normals = o3d.utility.Vector3dVector(normals)
|
||||||
|
o3d.io.write_point_cloud(name, pcd)
|
||||||
|
|
||||||
|
def export_mesh(name, v, f):
|
||||||
|
if len(v.shape) > 2:
|
||||||
|
v, f = v[0], f[0]
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.detach().cpu().numpy()
|
||||||
|
f = f.detach().cpu().numpy()
|
||||||
|
mesh = o3d.geometry.TriangleMesh()
|
||||||
|
mesh.vertices = o3d.utility.Vector3dVector(v)
|
||||||
|
mesh.triangles = o3d.utility.Vector3iVector(f)
|
||||||
|
o3d.io.write_triangle_mesh(name, mesh)
|
||||||
|
|
||||||
|
def scale2onet(p, scale=1.2):
|
||||||
|
'''
|
||||||
|
Scale the point cloud from SAP to ONet range
|
||||||
|
'''
|
||||||
|
return (p - 0.5) * scale
|
||||||
|
|
||||||
|
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
|
||||||
|
if model is not None:
|
||||||
|
if schedule is not None:
|
||||||
|
optimizer = torch.optim.Adam([
|
||||||
|
{"params": model.parameters(),
|
||||||
|
"lr": schedule[0].get_learning_rate(epoch)},
|
||||||
|
{"params": inputs,
|
||||||
|
"lr": schedule[1].get_learning_rate(epoch)}])
|
||||||
|
elif 'lr' in cfg['train']:
|
||||||
|
optimizer = torch.optim.Adam([
|
||||||
|
{"params": model.parameters(),
|
||||||
|
"lr": float(cfg['train']['lr'])},
|
||||||
|
{"params": inputs,
|
||||||
|
"lr": float(cfg['train']['lr_pcl'])}])
|
||||||
|
else:
|
||||||
|
raise Exception('no known learning rate')
|
||||||
|
else:
|
||||||
|
if schedule is not None:
|
||||||
|
optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
|
||||||
|
else:
|
||||||
|
optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
|
||||||
|
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def is_url(url):
|
||||||
|
scheme = urllib.parse.urlparse(url).scheme
|
||||||
|
return scheme in ('http', 'https')
|
||||||
|
|
||||||
|
def load_url(url):
|
||||||
|
'''Load a module dictionary from url.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): url to saved model
|
||||||
|
'''
|
||||||
|
print(url)
|
||||||
|
print('=> Loading checkpoint from url...')
|
||||||
|
state_dict = model_zoo.load_url(url, progress=True)
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianSmoothing(nn.Module):
|
||||||
|
"""
|
||||||
|
Apply gaussian smoothing on a
|
||||||
|
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
||||||
|
in the input using a depthwise convolution.
|
||||||
|
Arguments:
|
||||||
|
channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
|
||||||
|
kernel_size (int, sequence): Size of the gaussian kernel.
|
||||||
|
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
||||||
|
dim (int, optional): The number of dimensions of the data.
|
||||||
|
Default value is 2 (spatial).
|
||||||
|
"""
|
||||||
|
def __init__(self, channels, kernel_size, sigma, dim=3):
|
||||||
|
super(GaussianSmoothing, self).__init__()
|
||||||
|
if isinstance(kernel_size, numbers.Number):
|
||||||
|
kernel_size = [kernel_size] * dim
|
||||||
|
if isinstance(sigma, numbers.Number):
|
||||||
|
sigma = [sigma] * dim
|
||||||
|
|
||||||
|
# The gaussian kernel is the product of the
|
||||||
|
# gaussian function of each dimension.
|
||||||
|
kernel = 1
|
||||||
|
meshgrids = torch.meshgrid(
|
||||||
|
[
|
||||||
|
torch.arange(size, dtype=torch.float32)
|
||||||
|
for size in kernel_size
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||||
|
mean = (size - 1) / 2
|
||||||
|
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
|
||||||
|
torch.exp(-((mgrid - mean) / std) ** 2 / 2)
|
||||||
|
|
||||||
|
# Make sure sum of values in gaussian kernel equals 1.
|
||||||
|
kernel = kernel / torch.sum(kernel)
|
||||||
|
|
||||||
|
# Reshape to depthwise convolutional weight
|
||||||
|
kernel = kernel.view(1, 1, *kernel.size())
|
||||||
|
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||||
|
|
||||||
|
self.register_buffer('weight', kernel)
|
||||||
|
self.groups = channels
|
||||||
|
|
||||||
|
if dim == 1:
|
||||||
|
self.conv = F.conv1d
|
||||||
|
elif dim == 2:
|
||||||
|
self.conv = F.conv2d
|
||||||
|
elif dim == 3:
|
||||||
|
self.conv = F.conv3d
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
"""
|
||||||
|
Apply gaussian filter to input.
|
||||||
|
Arguments:
|
||||||
|
input (torch.Tensor): Input to apply gaussian filter on.
|
||||||
|
Returns:
|
||||||
|
filtered (torch.Tensor): Filtered output.
|
||||||
|
"""
|
||||||
|
return self.conv(input, weight=self.weight, groups=self.groups)
|
||||||
|
|
||||||
|
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
|
||||||
|
def get_learning_rate_schedules(schedule_specs):
|
||||||
|
|
||||||
|
schedules = []
|
||||||
|
|
||||||
|
for key in schedule_specs.keys():
|
||||||
|
schedules.append(StepLearningRateSchedule(
|
||||||
|
schedule_specs[key]['initial'],
|
||||||
|
schedule_specs[key]["interval"],
|
||||||
|
schedule_specs[key]["factor"],
|
||||||
|
schedule_specs[key]["final"]))
|
||||||
|
return schedules
|
||||||
|
|
||||||
|
class LearningRateSchedule:
|
||||||
|
def get_learning_rate(self, epoch):
|
||||||
|
pass
|
||||||
|
class StepLearningRateSchedule(LearningRateSchedule):
|
||||||
|
def __init__(self, initial, interval, factor, final=1e-6):
|
||||||
|
self.initial = float(initial)
|
||||||
|
self.interval = interval
|
||||||
|
self.factor = factor
|
||||||
|
self.final = float(final)
|
||||||
|
|
||||||
|
def get_learning_rate(self, epoch):
|
||||||
|
lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
|
||||||
|
if lr > self.final:
|
||||||
|
return lr
|
||||||
|
else:
|
||||||
|
return self.final
|
||||||
|
|
||||||
|
def adjust_learning_rate(lr_schedules, optimizer, epoch):
|
||||||
|
for i, param_group in enumerate(optimizer.param_groups):
|
||||||
|
param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
|
175
src/visualize.py
Normal file
175
src/visualize.py
Normal file
|
@ -0,0 +1,175 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import open3d as o3d
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from skimage import measure
|
||||||
|
from src.utils import calc_inters_points, grid_interp
|
||||||
|
from scipy import ndimage
|
||||||
|
from tqdm import trange
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
from pdb import set_trace as st
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None):
|
||||||
|
''' Visualization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (dict): data dictionary
|
||||||
|
depth (int): PSR depth
|
||||||
|
out_path (str): output path for the mesh
|
||||||
|
'''
|
||||||
|
mesh = o3d.geometry.TriangleMesh()
|
||||||
|
mesh.vertices = o3d.utility.Vector3dVector(verts)
|
||||||
|
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
||||||
|
mesh.paint_uniform_color(np.array([0.7,0.7,0.7]))
|
||||||
|
if color_v is not None:
|
||||||
|
mesh.vertex_colors = o3d.utility.Vector3dVector(color_v)
|
||||||
|
|
||||||
|
if vis is not None:
|
||||||
|
dir_o3d = cfg['train']['dir_o3d']
|
||||||
|
wire = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
|
||||||
|
|
||||||
|
p = points.squeeze(0).detach().cpu().numpy()
|
||||||
|
n = normals.squeeze(0).detach().cpu().numpy()
|
||||||
|
pcd = o3d.geometry.PointCloud()
|
||||||
|
pcd.points = o3d.utility.Vector3dVector(p)
|
||||||
|
pcd.normals = o3d.utility.Vector3dVector(n)
|
||||||
|
pcd.paint_uniform_color(np.array([0.7,0.7,1.0]))
|
||||||
|
# pcd = pcd.uniform_down_sample(5)
|
||||||
|
|
||||||
|
vis.clear_geometries()
|
||||||
|
vis.add_geometry(mesh)
|
||||||
|
vis.update_geometry(mesh)
|
||||||
|
|
||||||
|
#! Thingi wheel - an example for how to change cameras in Open3D viewers
|
||||||
|
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
|
||||||
|
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
|
||||||
|
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
|
||||||
|
vis.get_view_control().set_zoom(0.7)
|
||||||
|
vis.poll_events()
|
||||||
|
|
||||||
|
out_path = os.path.join(dir_o3d, '{}.jpg'.format(it))
|
||||||
|
vis.capture_screen_image(out_path)
|
||||||
|
|
||||||
|
vis.clear_geometries()
|
||||||
|
vis.add_geometry(pcd, reset_bounding_box=False)
|
||||||
|
vis.update_geometry(pcd)
|
||||||
|
vis.get_render_option().point_show_normal=True # visualize point normals
|
||||||
|
|
||||||
|
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
|
||||||
|
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
|
||||||
|
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
|
||||||
|
vis.get_view_control().set_zoom(0.7)
|
||||||
|
vis.poll_events()
|
||||||
|
|
||||||
|
out_path = os.path.join(dir_o3d, '{}_pcd.jpg'.format(it))
|
||||||
|
vis.capture_screen_image(out_path)
|
||||||
|
|
||||||
|
def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'):
|
||||||
|
if pose is not None:
|
||||||
|
device = psr_grid.device
|
||||||
|
# get world coordinate of grid points [-1, 1]
|
||||||
|
res = psr_grid.shape[-1]
|
||||||
|
x = torch.linspace(-1, 1, steps=res)
|
||||||
|
co_x, co_y, co_z = torch.meshgrid(x, x, x)
|
||||||
|
co_grid = torch.stack(
|
||||||
|
[co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)],
|
||||||
|
dim=1).to(device).unsqueeze(0)
|
||||||
|
|
||||||
|
# visualize the projected occ_soft value
|
||||||
|
res = 128
|
||||||
|
psr_grid = psr_grid.reshape(-1)
|
||||||
|
out_mask = psr_grid>0
|
||||||
|
in_mask = psr_grid<0
|
||||||
|
pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze()
|
||||||
|
vis_mask = (pix[..., 0]>=0) & (pix[..., 0]<=res-1) & \
|
||||||
|
(pix[..., 1]>=0) & (pix[..., 1]<=res-1)
|
||||||
|
pix_out = pix[vis_mask & out_mask]
|
||||||
|
pix_in = pix[vis_mask & in_mask]
|
||||||
|
|
||||||
|
img = torch.ones([res,res]).to(device)
|
||||||
|
psr_grid = torch.sigmoid(- psr_grid * 5)
|
||||||
|
img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask]
|
||||||
|
img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask]
|
||||||
|
# save_image(img, 'tmp.png', normalize=True)
|
||||||
|
return img
|
||||||
|
elif out_dir is not None:
|
||||||
|
dir_psr_vis = out_dir
|
||||||
|
os.makedirs(dir_psr_vis, exist_ok=True)
|
||||||
|
psr_grid = psr_grid.squeeze().detach().cpu().numpy()
|
||||||
|
axis = ['x', 'y', 'z']
|
||||||
|
s = psr_grid.shape[0]
|
||||||
|
for idx in trange(s):
|
||||||
|
my_dpi = 100
|
||||||
|
plt.figure(figsize=(1000/my_dpi, 300/my_dpi), dpi=my_dpi)
|
||||||
|
plt.subplot(1, 3, 1)
|
||||||
|
plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode='nearest'), cmap='nipy_spectral')
|
||||||
|
plt.clim(-1, 1)
|
||||||
|
plt.colorbar()
|
||||||
|
plt.title('x')
|
||||||
|
plt.grid("off")
|
||||||
|
plt.axis("off")
|
||||||
|
|
||||||
|
plt.subplot(1, 3, 2)
|
||||||
|
plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode='nearest'), cmap='nipy_spectral')
|
||||||
|
plt.clim(-1, 1)
|
||||||
|
plt.colorbar()
|
||||||
|
plt.title('y')
|
||||||
|
plt.grid("off")
|
||||||
|
plt.axis("off")
|
||||||
|
|
||||||
|
plt.subplot(1, 3, 3)
|
||||||
|
plt.imshow(ndimage.rotate(psr_grid[:,:,idx], 90, mode='nearest'), cmap='nipy_spectral')
|
||||||
|
plt.clim(-1, 1)
|
||||||
|
plt.colorbar()
|
||||||
|
plt.title('z')
|
||||||
|
plt.grid("off")
|
||||||
|
plt.axis("off")
|
||||||
|
|
||||||
|
|
||||||
|
plt.savefig(os.path.join(dir_psr_vis, '{}'.format(idx)), pad_inches = 0, dpi=100)
|
||||||
|
plt.close()
|
||||||
|
os.system("rm {}/{}".format(dir_psr_vis, out_video_name))
|
||||||
|
os.system("ffmpeg -framerate 25 -start_number 0 -i {}/%d.png -pix_fmt yuv420p -crf 17 {}/{}".format(dir_psr_vis, dir_psr_vis, out_video_name))
|
||||||
|
|
||||||
|
def visualize_mesh_phong(v, f, n, pose, img_size, name, device='cpu'):
|
||||||
|
#! Mesh rendering using Phong shading model
|
||||||
|
_, mask, f_p, w = calc_inters_points(v, f, pose, img_size)
|
||||||
|
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
|
||||||
|
n_inters = w[..., 0, None] * n_a.squeeze() + \
|
||||||
|
w[..., 1, None] * n_b.squeeze() + \
|
||||||
|
w[..., 2, None] * n_c.squeeze()
|
||||||
|
n_inters = n_inters.detach().to(device)
|
||||||
|
light_source = -pose.R@pose.T.squeeze()
|
||||||
|
light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float()
|
||||||
|
diffuse_per = torch.Tensor([0.7,0.7,0.7]).float()
|
||||||
|
ambiant = torch.Tensor([0.3,0.3,0.3]).float()
|
||||||
|
|
||||||
|
diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
phong = torch.ones([img_size[0]*img_size[1], 3]).to(device)
|
||||||
|
phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0)
|
||||||
|
pp = phong.reshape(img_size[0], img_size[1], -1)
|
||||||
|
save_image(pp.permute(2, 0, 1), name)
|
||||||
|
|
||||||
|
def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None):
|
||||||
|
p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
|
||||||
|
# normals for p_inters
|
||||||
|
n_inters = None
|
||||||
|
if n is not None:
|
||||||
|
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
|
||||||
|
n_inters = w[..., 0, None] * n_a.squeeze() + \
|
||||||
|
w[..., 1, None] * n_b.squeeze() + \
|
||||||
|
w[..., 2, None] * n_c.squeeze()
|
||||||
|
if ray is not None:
|
||||||
|
ray = ray.squeeze()[mask]
|
||||||
|
|
||||||
|
fea = None
|
||||||
|
if fea_grid is not None:
|
||||||
|
fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze()
|
||||||
|
|
||||||
|
# use MLP to regress color
|
||||||
|
color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze()
|
||||||
|
|
||||||
|
return color_pred, mask
|
185
train.py
Normal file
185
train.py
Normal file
|
@ -0,0 +1,185 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
abspath = os.path.abspath(__file__)
|
||||||
|
dname = os.path.dirname(abspath)
|
||||||
|
os.chdir(dname)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
import open3d as o3d
|
||||||
|
|
||||||
|
import numpy as np; np.set_printoptions(precision=4)
|
||||||
|
import shutil, argparse, time
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from src import config
|
||||||
|
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
|
||||||
|
from src.training import Trainer
|
||||||
|
from src.model import Encode2Points
|
||||||
|
from src.utils import load_config, initialize_logger, \
|
||||||
|
AverageMeter, load_model_manual
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||||
|
parser.add_argument('config', type=str, help='Path to config file.')
|
||||||
|
parser.add_argument('--no_cuda', action='store_true', default=False,
|
||||||
|
help='disables CUDA training')
|
||||||
|
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
cfg = load_config(args.config, 'configs/default.yaml')
|
||||||
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
|
input_type = cfg['data']['input_type']
|
||||||
|
batch_size = cfg['train']['batch_size']
|
||||||
|
model_selection_metric = cfg['train']['model_selection_metric']
|
||||||
|
|
||||||
|
# PYTORCH VERSION > 1.0.0
|
||||||
|
assert(float(torch.__version__.split('.')[-3]) > 0)
|
||||||
|
|
||||||
|
# boiler-plate
|
||||||
|
if cfg['train']['timestamp']:
|
||||||
|
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||||
|
logger = initialize_logger(cfg)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml'))
|
||||||
|
|
||||||
|
logger.info("using GPU: " + torch.cuda.get_device_name(0))
|
||||||
|
|
||||||
|
# TensorboardX writer
|
||||||
|
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
|
||||||
|
if not os.path.exists(tblogdir):
|
||||||
|
os.makedirs(tblogdir, exist_ok=True)
|
||||||
|
writer = SummaryWriter(log_dir=tblogdir)
|
||||||
|
|
||||||
|
|
||||||
|
inputs = None
|
||||||
|
train_dataset = config.get_dataset('train', cfg)
|
||||||
|
val_dataset = config.get_dataset('val', cfg)
|
||||||
|
vis_dataset = config.get_dataset('vis', cfg)
|
||||||
|
|
||||||
|
|
||||||
|
collate_fn = collate_remove_none
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'], shuffle=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
worker_init_fn=worker_init_fn)
|
||||||
|
|
||||||
|
val_loader = torch.utils.data.DataLoader(
|
||||||
|
val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
|
||||||
|
collate_fn=collate_remove_none,
|
||||||
|
worker_init_fn=worker_init_fn)
|
||||||
|
|
||||||
|
vis_loader = torch.utils.data.DataLoader(
|
||||||
|
vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
worker_init_fn=worker_init_fn)
|
||||||
|
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
|
||||||
|
else:
|
||||||
|
model = Encode2Points(cfg).to(device)
|
||||||
|
|
||||||
|
n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
logger.info('Number of parameters: %d'% n_parameter)
|
||||||
|
# load model
|
||||||
|
try:
|
||||||
|
# load model
|
||||||
|
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
||||||
|
load_model_manual(state_dict['state_dict'], model)
|
||||||
|
|
||||||
|
out = "Load model from iteration %d" % state_dict.get('it', 0)
|
||||||
|
logger.info(out)
|
||||||
|
# load point cloud
|
||||||
|
except:
|
||||||
|
state_dict = dict()
|
||||||
|
|
||||||
|
metric_val_best = state_dict.get(
|
||||||
|
'loss_val_best', np.inf)
|
||||||
|
|
||||||
|
logger.info('Current best validation metric (%s): %.8f'
|
||||||
|
% (model_selection_metric, metric_val_best))
|
||||||
|
|
||||||
|
LR = float(cfg['train']['lr'])
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=LR)
|
||||||
|
|
||||||
|
start_epoch = state_dict.get('epoch', -1)
|
||||||
|
it = state_dict.get('it', -1)
|
||||||
|
|
||||||
|
trainer = Trainer(cfg, optimizer, device=device)
|
||||||
|
runtime = {}
|
||||||
|
runtime['all'] = AverageMeter()
|
||||||
|
|
||||||
|
# training loop
|
||||||
|
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
|
||||||
|
|
||||||
|
for batch in train_loader:
|
||||||
|
it += 1
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
loss, loss_each = trainer.train_step(inputs, batch, model)
|
||||||
|
|
||||||
|
# measure elapsed time
|
||||||
|
end = time.time()
|
||||||
|
runtime['all'].update(end - start)
|
||||||
|
|
||||||
|
if it % cfg['train']['print_every'] == 0:
|
||||||
|
log_text = ('[Epoch %02d] it=%d, loss=%.4f') %(epoch, it, loss)
|
||||||
|
writer.add_scalar('train/loss', loss, it)
|
||||||
|
if loss_each is not None:
|
||||||
|
for k, l in loss_each.items():
|
||||||
|
if l.item() != 0.:
|
||||||
|
log_text += (' loss_%s=%.4f') % (k, l.item())
|
||||||
|
writer.add_scalar('train/%s' % k, l, it)
|
||||||
|
|
||||||
|
log_text += (' time=%.3f / %.2f') % (runtime['all'].val, runtime['all'].sum)
|
||||||
|
logger.info(log_text)
|
||||||
|
|
||||||
|
if (it>0)& (it % cfg['train']['visualize_every'] == 0):
|
||||||
|
for i, batch_vis in enumerate(vis_loader):
|
||||||
|
trainer.save(model, batch_vis, it, i)
|
||||||
|
if i >= 4:
|
||||||
|
break
|
||||||
|
logger.info('Saved mesh and pointcloud')
|
||||||
|
|
||||||
|
# run validation
|
||||||
|
if it > 0 and (it % cfg['train']['validate_every']) == 0:
|
||||||
|
eval_dict = trainer.evaluate(val_loader, model)
|
||||||
|
metric_val = eval_dict[model_selection_metric]
|
||||||
|
logger.info('Validation metric (%s): %.4f'
|
||||||
|
% (model_selection_metric, metric_val))
|
||||||
|
|
||||||
|
for k, v in eval_dict.items():
|
||||||
|
writer.add_scalar('val/%s' % k, v, it)
|
||||||
|
|
||||||
|
if -(metric_val - metric_val_best) >= 0:
|
||||||
|
metric_val_best = metric_val
|
||||||
|
logger.info('New best model (loss %.4f)' % metric_val_best)
|
||||||
|
state = {'epoch': epoch,
|
||||||
|
'it': it,
|
||||||
|
'loss_val_best': metric_val_best}
|
||||||
|
state['state_dict'] = model.state_dict()
|
||||||
|
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt'))
|
||||||
|
|
||||||
|
# save checkpoint
|
||||||
|
if (epoch > 0) & (it % cfg['train']['checkpoint_every'] == 0):
|
||||||
|
state = {'epoch': epoch,
|
||||||
|
'it': it,
|
||||||
|
'loss_val_best': metric_val_best}
|
||||||
|
pcl = None
|
||||||
|
state['state_dict'] = model.state_dict()
|
||||||
|
|
||||||
|
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
||||||
|
|
||||||
|
if (it % cfg['train']['backup_every'] == 0):
|
||||||
|
torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt'))
|
||||||
|
logger.info("Backup model at iteration %d" % it)
|
||||||
|
logger.info("Save new model at iteration %d" % it)
|
||||||
|
|
||||||
|
done=time.time()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in a new issue