Compare commits
10 commits
79c35e453a
...
43164f892f
Author | SHA1 | Date | |
---|---|---|---|
43164f892f | |||
377132fdd5 | |||
241631628b | |||
aa60b82868 | |||
41b213e4a5 | |||
ee1be08bce | |||
ee2501d2bd | |||
388991be13 | |||
671b0c7c81 | |||
3967559875 |
15
.editorconfig
Normal file
15
.editorconfig
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
# EditorConfig is awesome: https://EditorConfig.org
|
||||||
|
|
||||||
|
# top-most EditorConfig file
|
||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 4
|
||||||
|
end_of_line = lf
|
||||||
|
charset = utf-8
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
|
||||||
|
[*.{json,toml,yaml,yml,md}]
|
||||||
|
indent_size = 2
|
28
.gitattributes
vendored
Normal file
28
.gitattributes
vendored
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# https://github.com/alexkaratarakis/gitattributes/blob/master/Python.gitattributes
|
||||||
|
# Basic .gitattributes for a python repo.
|
||||||
|
|
||||||
|
# Source files
|
||||||
|
*.pxd text diff=python
|
||||||
|
*.py text diff=python
|
||||||
|
*.py3 text diff=python
|
||||||
|
*.pyw text diff=python
|
||||||
|
*.pyx text diff=python
|
||||||
|
*.pyz text diff=python
|
||||||
|
*.pyi text diff=python
|
||||||
|
|
||||||
|
# Binary files
|
||||||
|
*.db binary
|
||||||
|
*.p binary
|
||||||
|
*.pkl binary
|
||||||
|
*.pickle binary
|
||||||
|
*.pyc binary export-ignore
|
||||||
|
*.pyo binary export-ignore
|
||||||
|
*.pyd binary
|
||||||
|
|
||||||
|
# Jupyter notebook
|
||||||
|
*.ipynb text
|
||||||
|
|
||||||
|
# Note: .db, .p, and .pkl files are associated
|
||||||
|
# with the python modules ``pickle``, ``dbm.*``,
|
||||||
|
# ``shelve``, ``marshal``, ``anydbm``, & ``bsddb``
|
||||||
|
# (among others).
|
71
.gitignore
vendored
71
.gitignore
vendored
|
@ -1,32 +1,22 @@
|
||||||
/out
|
# Personnal ignores
|
||||||
/data
|
lightning_logs/
|
||||||
.vscode
|
|
||||||
.cache
|
|
||||||
*.pyc
|
|
||||||
*.pyd
|
|
||||||
*.pt
|
|
||||||
*.so
|
|
||||||
*.o
|
|
||||||
*.prof
|
|
||||||
*.swp
|
|
||||||
*.lib
|
|
||||||
*.obj
|
|
||||||
*.exp
|
|
||||||
.nfs*
|
|
||||||
*.jpg
|
|
||||||
*.png
|
|
||||||
*.ply
|
|
||||||
*.off
|
|
||||||
*.npz
|
|
||||||
*.txt
|
|
||||||
# *.sh
|
|
||||||
|
|
||||||
|
*.tar.gz
|
||||||
|
*.vtk
|
||||||
|
|
||||||
|
demo/
|
||||||
|
|
||||||
|
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
|
# Basic .gitignore for a python repo.
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
.Python
|
.Python
|
||||||
build/
|
build/
|
||||||
|
@ -41,7 +31,6 @@ parts/
|
||||||
sdist/
|
sdist/
|
||||||
var/
|
var/
|
||||||
wheels/
|
wheels/
|
||||||
pip-wheel-metadata/
|
|
||||||
share/python-wheels/
|
share/python-wheels/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
.installed.cfg
|
.installed.cfg
|
||||||
|
@ -71,6 +60,7 @@ coverage.xml
|
||||||
*.py,cover
|
*.py,cover
|
||||||
.hypothesis/
|
.hypothesis/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
# Translations
|
# Translations
|
||||||
*.mo
|
*.mo
|
||||||
|
@ -93,6 +83,7 @@ instance/
|
||||||
docs/_build/
|
docs/_build/
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
target/
|
target/
|
||||||
|
|
||||||
# Jupyter Notebook
|
# Jupyter Notebook
|
||||||
|
@ -103,7 +94,9 @@ profile_default/
|
||||||
ipython_config.py
|
ipython_config.py
|
||||||
|
|
||||||
# pyenv
|
# pyenv
|
||||||
.python-version
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
# pipenv
|
# pipenv
|
||||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
@ -112,7 +105,22 @@ ipython_config.py
|
||||||
# install all needed dependencies.
|
# install all needed dependencies.
|
||||||
#Pipfile.lock
|
#Pipfile.lock
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
__pypackages__/
|
__pypackages__/
|
||||||
|
|
||||||
# Celery stuff
|
# Celery stuff
|
||||||
|
@ -148,3 +156,16 @@ dmypy.json
|
||||||
|
|
||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
|
@ -1,5 +1,5 @@
|
||||||
{
|
{
|
||||||
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pyg/bin/python", // required for python ide tools
|
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/sap/bin/python", // required for python ide tools
|
||||||
"python.terminal.activateEnvironment": false, // or else terminal gets bugged
|
"python.terminal.activateEnvironment": false, // or else terminal gets bugged
|
||||||
// good pratice settings
|
// good pratice settings
|
||||||
"editor.formatOnSave": true,
|
"editor.formatOnSave": true,
|
||||||
|
@ -33,7 +33,7 @@
|
||||||
"path": "bash",
|
"path": "bash",
|
||||||
"icon": "rocket",
|
"icon": "rocket",
|
||||||
"env": {
|
"env": {
|
||||||
"CONDAENV": "pyg",
|
"CONDAENV": "sap",
|
||||||
},
|
},
|
||||||
"args": [
|
"args": [
|
||||||
"-c",
|
"-c",
|
||||||
|
@ -42,7 +42,6 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"terminal.integrated.env.linux": {
|
"terminal.integrated.env.linux": {
|
||||||
"PYTHONPATH": "${workspaceFolder}/src/",
|
|
||||||
"SLURM_JOB_ID": null,
|
"SLURM_JOB_ID": null,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,12 +33,13 @@ conda activate sap
|
||||||
|
|
||||||
Next, you should install [PyTorch3D](https://pytorch3d.org/) (**>=0.5**) yourself from the [official instruction](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux).
|
Next, you should install [PyTorch3D](https://pytorch3d.org/) (**>=0.5**) yourself from the [official instruction](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux).
|
||||||
|
|
||||||
And install [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter):
|
```bash
|
||||||
```sh
|
git clone https://github.com/facebookresearch/pytorch3d.git
|
||||||
conda install pytorch-scatter -c pyg
|
cd pytorch3d
|
||||||
|
module load compilers
|
||||||
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Demo - Quick Start
|
## Demo - Quick Start
|
||||||
|
|
||||||
First, run the script to get the demo data:
|
First, run the script to get the demo data:
|
||||||
|
|
|
@ -1,28 +1,44 @@
|
||||||
name: sap
|
name: sap
|
||||||
|
|
||||||
channels:
|
channels:
|
||||||
|
- nodefaults
|
||||||
- conda-forge
|
- conda-forge
|
||||||
- pytorch
|
- pytorch
|
||||||
- defaults
|
- nvidia
|
||||||
|
- pyg
|
||||||
|
|
||||||
dependencies:
|
dependencies:
|
||||||
- python
|
#---# basic python
|
||||||
- pytorch
|
- python=3.8
|
||||||
- torchvision
|
|
||||||
- cudatoolkit
|
|
||||||
- numpy
|
|
||||||
- matplotlib
|
|
||||||
- pyyaml
|
|
||||||
- scipy
|
|
||||||
- tqdm
|
- tqdm
|
||||||
|
- pyyaml
|
||||||
|
#---# visu
|
||||||
|
- matplotlib
|
||||||
|
#---# scientific
|
||||||
|
- numpy
|
||||||
|
- scipy
|
||||||
- trimesh
|
- trimesh
|
||||||
- igl
|
- igl
|
||||||
|
#---# pytorch
|
||||||
|
- pytorch
|
||||||
|
- pytorch-cuda=11.8
|
||||||
|
- cudatoolkit
|
||||||
|
- torchvision
|
||||||
|
- torch-scatter
|
||||||
|
#---# tooling (linting, typing...)
|
||||||
|
- ruff
|
||||||
|
- mypy
|
||||||
|
- black
|
||||||
|
- isort
|
||||||
|
#---# logging
|
||||||
- tensorboard
|
- tensorboard
|
||||||
|
#---# pip shit
|
||||||
- pip
|
- pip
|
||||||
- pip:
|
- pip:
|
||||||
- plyfile==0.7
|
- plyfile
|
||||||
# - open3d>=0.11.1
|
- scikit-image
|
||||||
- scikit-image>=0.18.0
|
- python-mnist
|
||||||
- python-mnist==0.7
|
- opencv-python
|
||||||
- opencv-python>=4.4
|
- av
|
||||||
- av==8.0.3
|
- pykdtree
|
||||||
- pykdtree==1.3.4
|
- ipdb
|
||||||
- ipdb==0.13.7
|
|
||||||
|
|
148
eval_meshes.py
148
eval_meshes.py
|
@ -1,155 +1,145 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
import trimesh
|
import trimesh
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import numpy as np; np.set_printoptions(precision=4)
|
|
||||||
import shutil, argparse, time, os
|
|
||||||
import pandas as pd
|
|
||||||
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
|
|
||||||
from src.training import Trainer
|
|
||||||
from src.model import Encode2Points
|
|
||||||
from src.data import PointCloudField, IndexField, Shapes3dDataset
|
|
||||||
from src.utils import load_config, load_pointcloud
|
|
||||||
from src.eval import MeshEvaluator
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pdb import set_trace as st
|
|
||||||
|
from src.data import IndexField, PointCloudField, Shapes3dDataset
|
||||||
|
from src.eval import MeshEvaluator
|
||||||
|
from src.utils import load_config, load_pointcloud
|
||||||
|
|
||||||
|
np.set_printoptions(precision=4)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||||
parser.add_argument('config', type=str, help='Path to config file.')
|
parser.add_argument("config", type=str, help="Path to config file.")
|
||||||
parser.add_argument('--no_cuda', action='store_true', default=False,
|
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
|
||||||
help='disables CUDA training')
|
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
|
||||||
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")
|
||||||
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
cfg = load_config(args.config, 'configs/default.yaml')
|
cfg = load_config(args.config, "configs/default.yaml")
|
||||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
device = torch.device("cuda" if use_cuda else "cpu")
|
torch.device("cuda" if use_cuda else "cpu")
|
||||||
data_type = cfg['data']['data_type']
|
cfg["data"]["data_type"]
|
||||||
# Shorthands
|
# Shorthands
|
||||||
out_dir = cfg['train']['out_dir']
|
out_dir = cfg["train"]["out_dir"]
|
||||||
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
|
generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])
|
||||||
|
|
||||||
if cfg['generation'].get('iter', 0)!=0:
|
if cfg["generation"].get("iter", 0) != 0:
|
||||||
generation_dir += '_%04d'%cfg['generation']['iter']
|
generation_dir += "_%04d" % cfg["generation"]["iter"]
|
||||||
elif args.iter is not None:
|
elif args.iter is not None:
|
||||||
generation_dir += '_%04d'%args.iter
|
generation_dir += "_%04d" % args.iter
|
||||||
|
|
||||||
print('Evaluate meshes under %s'%generation_dir)
|
print("Evaluate meshes under %s" % generation_dir)
|
||||||
|
|
||||||
out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl')
|
out_file = os.path.join(generation_dir, "eval_meshes_full.pkl")
|
||||||
out_file_class = os.path.join(generation_dir, 'eval_meshes.csv')
|
out_file_class = os.path.join(generation_dir, "eval_meshes.csv")
|
||||||
|
|
||||||
# PYTORCH VERSION > 1.0.0
|
# PYTORCH VERSION > 1.0.0
|
||||||
assert(float(torch.__version__.split('.')[-3]) > 0)
|
assert float(torch.__version__.split(".")[-3]) > 0
|
||||||
|
|
||||||
pointcloud_field = PointCloudField(cfg['data']['pointcloud_file'])
|
pointcloud_field = PointCloudField(cfg["data"]["pointcloud_file"])
|
||||||
fields = {
|
fields = {
|
||||||
'pointcloud': pointcloud_field,
|
"pointcloud": pointcloud_field,
|
||||||
'idx': IndexField(),
|
"idx": IndexField(),
|
||||||
}
|
}
|
||||||
|
|
||||||
print('Test split: ', cfg['data']['test_split'])
|
print("Test split: ", cfg["data"]["test_split"])
|
||||||
|
|
||||||
dataset_folder = cfg['data']['path']
|
dataset_folder = cfg["data"]["path"]
|
||||||
dataset = Shapes3dDataset(
|
dataset = Shapes3dDataset(
|
||||||
dataset_folder, fields,
|
dataset_folder, fields, cfg["data"]["test_split"], categories=cfg["data"]["class"], cfg=cfg,
|
||||||
cfg['data']['test_split'],
|
)
|
||||||
categories=cfg['data']['class'], cfg=cfg)
|
|
||||||
|
|
||||||
# Loader
|
# Loader
|
||||||
test_loader = torch.utils.data.DataLoader(
|
test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)
|
||||||
dataset, batch_size=1, num_workers=0, shuffle=False)
|
|
||||||
|
|
||||||
# Evaluator
|
# Evaluator
|
||||||
evaluator = MeshEvaluator(n_points=100000)
|
evaluator = MeshEvaluator(n_points=100000)
|
||||||
|
|
||||||
eval_dicts = []
|
eval_dicts = []
|
||||||
print('Evaluating meshes...')
|
print("Evaluating meshes...")
|
||||||
for it, data in enumerate(tqdm(test_loader)):
|
for _it, data in enumerate(tqdm(test_loader)):
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
print('Invalid data.')
|
print("Invalid data.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
mesh_dir = os.path.join(generation_dir, 'meshes')
|
mesh_dir = os.path.join(generation_dir, "meshes")
|
||||||
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
|
pointcloud_dir = os.path.join(generation_dir, "pointcloud")
|
||||||
|
|
||||||
|
|
||||||
# Get index etc.
|
# Get index etc.
|
||||||
idx = data['idx'].item()
|
idx = data["idx"].item()
|
||||||
try:
|
try:
|
||||||
model_dict = dataset.get_model_dict(idx)
|
model_dict = dataset.get_model_dict(idx)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
model_dict = {'model': str(idx), 'category': 'n/a'}
|
model_dict = {"model": str(idx), "category": "n/a"}
|
||||||
|
|
||||||
modelname = model_dict['model']
|
modelname = model_dict["model"]
|
||||||
category_id = model_dict['category']
|
category_id = model_dict["category"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
category_name = dataset.metadata[category_id].get('name', 'n/a')
|
category_name = dataset.metadata[category_id].get("name", "n/a")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
category_name = 'n/a'
|
category_name = "n/a"
|
||||||
|
|
||||||
if category_id != 'n/a':
|
if category_id != "n/a":
|
||||||
mesh_dir = os.path.join(mesh_dir, category_id)
|
mesh_dir = os.path.join(mesh_dir, category_id)
|
||||||
pointcloud_dir = os.path.join(pointcloud_dir, category_id)
|
pointcloud_dir = os.path.join(pointcloud_dir, category_id)
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
pointcloud_tgt = data['pointcloud'].squeeze(0).numpy()
|
pointcloud_tgt = data["pointcloud"].squeeze(0).numpy()
|
||||||
normals_tgt = data['pointcloud.normals'].squeeze(0).numpy()
|
normals_tgt = data["pointcloud.normals"].squeeze(0).numpy()
|
||||||
|
|
||||||
|
|
||||||
eval_dict = {
|
eval_dict = {
|
||||||
'idx': idx,
|
"idx": idx,
|
||||||
'class id': category_id,
|
"class id": category_id,
|
||||||
'class name': category_name,
|
"class name": category_name,
|
||||||
'modelname':modelname,
|
"modelname": modelname,
|
||||||
}
|
}
|
||||||
eval_dicts.append(eval_dict)
|
eval_dicts.append(eval_dict)
|
||||||
|
|
||||||
# Evaluate mesh
|
# Evaluate mesh
|
||||||
if cfg['test']['eval_mesh']:
|
if cfg["test"]["eval_mesh"]:
|
||||||
mesh_file = os.path.join(mesh_dir, '%s.off' % modelname)
|
mesh_file = os.path.join(mesh_dir, "%s.off" % modelname)
|
||||||
|
|
||||||
if os.path.exists(mesh_file):
|
if os.path.exists(mesh_file):
|
||||||
mesh = trimesh.load(mesh_file, process=False)
|
mesh = trimesh.load(mesh_file, process=False)
|
||||||
eval_dict_mesh = evaluator.eval_mesh(
|
eval_dict_mesh = evaluator.eval_mesh(mesh, pointcloud_tgt, normals_tgt)
|
||||||
mesh, pointcloud_tgt, normals_tgt)
|
|
||||||
for k, v in eval_dict_mesh.items():
|
for k, v in eval_dict_mesh.items():
|
||||||
eval_dict[k + ' (mesh)'] = v
|
eval_dict[k + " (mesh)"] = v
|
||||||
else:
|
else:
|
||||||
print('Warning: mesh does not exist: %s' % mesh_file)
|
print("Warning: mesh does not exist: %s" % mesh_file)
|
||||||
|
|
||||||
# Evaluate point cloud
|
# Evaluate point cloud
|
||||||
if cfg['test']['eval_pointcloud']:
|
if cfg["test"]["eval_pointcloud"]:
|
||||||
pointcloud_file = os.path.join(
|
pointcloud_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
|
||||||
pointcloud_dir, '%s.ply' % modelname)
|
|
||||||
|
|
||||||
if os.path.exists(pointcloud_file):
|
if os.path.exists(pointcloud_file):
|
||||||
pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
|
pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
|
||||||
eval_dict_pcl = evaluator.eval_pointcloud(
|
eval_dict_pcl = evaluator.eval_pointcloud(pointcloud, pointcloud_tgt)
|
||||||
pointcloud, pointcloud_tgt)
|
|
||||||
for k, v in eval_dict_pcl.items():
|
for k, v in eval_dict_pcl.items():
|
||||||
eval_dict[k + ' (pcl)'] = v
|
eval_dict[k + " (pcl)"] = v
|
||||||
else:
|
else:
|
||||||
print('Warning: pointcloud does not exist: %s'
|
print("Warning: pointcloud does not exist: %s" % pointcloud_file)
|
||||||
% pointcloud_file)
|
|
||||||
|
|
||||||
|
|
||||||
# Create pandas dataframe and save
|
# Create pandas dataframe and save
|
||||||
eval_df = pd.DataFrame(eval_dicts)
|
eval_df = pd.DataFrame(eval_dicts)
|
||||||
eval_df.set_index(['idx'], inplace=True)
|
eval_df.set_index(["idx"], inplace=True)
|
||||||
eval_df.to_pickle(out_file)
|
eval_df.to_pickle(out_file)
|
||||||
|
|
||||||
# Create CSV file with main statistics
|
# Create CSV file with main statistics
|
||||||
eval_df_class = eval_df.groupby(by=['class name']).mean()
|
eval_df_class = eval_df.groupby(by=["class name"]).mean()
|
||||||
eval_df_class.loc['mean'] = eval_df_class.mean()
|
eval_df_class.loc["mean"] = eval_df_class.mean()
|
||||||
eval_df_class.to_csv(out_file_class)
|
eval_df_class.to_csv(out_file_class)
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
print(eval_df_class)
|
print(eval_df_class)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
191
generate.py
191
generate.py
|
@ -1,127 +1,136 @@
|
||||||
import torch
|
import argparse
|
||||||
from torch.utils.data import Dataset, DataLoader
|
import os
|
||||||
import numpy as np; np.set_printoptions(precision=4)
|
import shutil
|
||||||
import shutil, argparse, time, os
|
|
||||||
import pandas as pd
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from src import config
|
|
||||||
from src.utils import mc_from_psr, export_mesh, export_pointcloud
|
import numpy as np
|
||||||
from src.dpsr import DPSR
|
import pandas as pd
|
||||||
from src.training import Trainer
|
import torch
|
||||||
from src.model import Encode2Points
|
|
||||||
from src.utils import load_config, load_model_manual, scale2onet, is_url, load_url
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pdb import set_trace as st
|
|
||||||
|
from src import config
|
||||||
|
from src.dpsr import DPSR
|
||||||
|
from src.model import Encode2Points
|
||||||
|
from src.utils import (
|
||||||
|
export_mesh,
|
||||||
|
export_pointcloud,
|
||||||
|
is_url,
|
||||||
|
load_config,
|
||||||
|
load_model_manual,
|
||||||
|
load_url,
|
||||||
|
mc_from_psr,
|
||||||
|
scale2onet,
|
||||||
|
)
|
||||||
|
|
||||||
|
np.set_printoptions(precision=4)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||||
parser.add_argument('config', type=str, help='Path to config file.')
|
parser.add_argument("config", type=str, help="Path to config file.")
|
||||||
parser.add_argument('--no_cuda', action='store_true', default=False,
|
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
|
||||||
help='disables CUDA training')
|
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
|
||||||
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")
|
||||||
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
cfg = load_config(args.config, 'configs/default.yaml')
|
cfg = load_config(args.config, "configs/default.yaml")
|
||||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
device = torch.device("cuda" if use_cuda else "cpu")
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
data_type = cfg['data']['data_type']
|
cfg["data"]["data_type"]
|
||||||
input_type = cfg['data']['input_type']
|
cfg["data"]["input_type"]
|
||||||
vis_n_outputs = cfg['generation']['vis_n_outputs']
|
vis_n_outputs = cfg["generation"]["vis_n_outputs"]
|
||||||
if vis_n_outputs is None:
|
if vis_n_outputs is None:
|
||||||
vis_n_outputs = -1
|
vis_n_outputs = -1
|
||||||
# Shorthands
|
# Shorthands
|
||||||
out_dir = cfg['train']['out_dir']
|
out_dir = cfg["train"]["out_dir"]
|
||||||
if not out_dir:
|
if not out_dir:
|
||||||
os.makedirs(out_dir)
|
os.makedirs(out_dir)
|
||||||
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
|
generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])
|
||||||
out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
|
out_time_file = os.path.join(generation_dir, "time_generation_full.pkl")
|
||||||
out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')
|
out_time_file_class = os.path.join(generation_dir, "time_generation.pkl")
|
||||||
|
|
||||||
# PYTORCH VERSION > 1.0.0
|
# PYTORCH VERSION > 1.0.0
|
||||||
assert(float(torch.__version__.split('.')[-3]) > 0)
|
assert float(torch.__version__.split(".")[-3]) > 0
|
||||||
|
|
||||||
dataset = config.get_dataset('test', cfg, return_idx=True)
|
dataset = config.get_dataset("test", cfg, return_idx=True)
|
||||||
test_loader = torch.utils.data.DataLoader(
|
test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)
|
||||||
dataset, batch_size=1, num_workers=0, shuffle=False)
|
|
||||||
|
|
||||||
model = Encode2Points(cfg).to(device)
|
model = Encode2Points(cfg).to(device)
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
try:
|
try:
|
||||||
if is_url(cfg['test']['model_file']):
|
if is_url(cfg["test"]["model_file"]):
|
||||||
state_dict = load_url(cfg['test']['model_file'])
|
state_dict = load_url(cfg["test"]["model_file"])
|
||||||
elif cfg['generation'].get('iter', 0)!=0:
|
elif cfg["generation"].get("iter", 0) != 0:
|
||||||
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% cfg['generation']['iter']))
|
state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % cfg["generation"]["iter"]))
|
||||||
generation_dir += '_%04d'%cfg['generation']['iter']
|
generation_dir += "_%04d" % cfg["generation"]["iter"]
|
||||||
elif args.iter is not None:
|
elif args.iter is not None:
|
||||||
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% args.iter))
|
state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % args.iter))
|
||||||
else:
|
else:
|
||||||
state_dict = torch.load(os.path.join(out_dir, 'model_best.pt'))
|
state_dict = torch.load(os.path.join(out_dir, "model_best.pt"))
|
||||||
|
|
||||||
load_model_manual(state_dict['state_dict'], model)
|
load_model_manual(state_dict["state_dict"], model)
|
||||||
|
|
||||||
except:
|
except:
|
||||||
print('Model loading error. Exiting.')
|
print("Model loading error. Exiting.")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
|
||||||
# Generator
|
# Generator
|
||||||
generator = config.get_generator(model, cfg, device=device)
|
generator = config.get_generator(model, cfg, device=device)
|
||||||
|
|
||||||
# Determine what to generate
|
# Determine what to generate
|
||||||
generate_mesh = cfg['generation']['generate_mesh']
|
generate_mesh = cfg["generation"]["generate_mesh"]
|
||||||
generate_pointcloud = cfg['generation']['generate_pointcloud']
|
generate_pointcloud = cfg["generation"]["generate_pointcloud"]
|
||||||
|
|
||||||
# Statistics
|
# Statistics
|
||||||
time_dicts = []
|
time_dicts = []
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
model.eval()
|
model.eval()
|
||||||
dpsr = DPSR(res=(cfg['generation']['psr_resolution'],
|
dpsr = DPSR(
|
||||||
cfg['generation']['psr_resolution'],
|
res=(
|
||||||
cfg['generation']['psr_resolution']),
|
cfg["generation"]["psr_resolution"],
|
||||||
sig= cfg['generation']['psr_sigma']).to(device)
|
cfg["generation"]["psr_resolution"],
|
||||||
|
cfg["generation"]["psr_resolution"],
|
||||||
|
),
|
||||||
|
sig=cfg["generation"]["psr_sigma"],
|
||||||
|
).to(device)
|
||||||
|
|
||||||
# Count how many models already created
|
# Count how many models already created
|
||||||
model_counter = defaultdict(int)
|
model_counter = defaultdict(int)
|
||||||
|
|
||||||
print('Generating...')
|
print("Generating...")
|
||||||
for it, data in enumerate(tqdm(test_loader)):
|
for _it, data in enumerate(tqdm(test_loader)):
|
||||||
|
|
||||||
# Output folders
|
# Output folders
|
||||||
mesh_dir = os.path.join(generation_dir, 'meshes')
|
mesh_dir = os.path.join(generation_dir, "meshes")
|
||||||
in_dir = os.path.join(generation_dir, 'input')
|
in_dir = os.path.join(generation_dir, "input")
|
||||||
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
|
pointcloud_dir = os.path.join(generation_dir, "pointcloud")
|
||||||
generation_vis_dir = os.path.join(generation_dir, 'vis', )
|
generation_vis_dir = os.path.join(generation_dir, "vis")
|
||||||
|
|
||||||
# Get index etc.
|
# Get index etc.
|
||||||
idx = data['idx'].item()
|
idx = data["idx"].item()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_dict = dataset.get_model_dict(idx)
|
model_dict = dataset.get_model_dict(idx)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
model_dict = {'model': str(idx), 'category': 'n/a'}
|
model_dict = {"model": str(idx), "category": "n/a"}
|
||||||
|
|
||||||
modelname = model_dict['model']
|
modelname = model_dict["model"]
|
||||||
category_id = model_dict['category']
|
category_id = model_dict["category"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
category_name = dataset.metadata[category_id].get('name', 'n/a')
|
category_name = dataset.metadata[category_id].get("name", "n/a")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
category_name = 'n/a'
|
category_name = "n/a"
|
||||||
|
|
||||||
if category_id != 'n/a':
|
if category_id != "n/a":
|
||||||
mesh_dir = os.path.join(mesh_dir, str(category_id))
|
mesh_dir = os.path.join(mesh_dir, str(category_id))
|
||||||
pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
|
pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
|
||||||
in_dir = os.path.join(in_dir, str(category_id))
|
in_dir = os.path.join(in_dir, str(category_id))
|
||||||
|
|
||||||
folder_name = str(category_id)
|
folder_name = str(category_id)
|
||||||
if category_name != 'n/a':
|
if category_name != "n/a":
|
||||||
folder_name = str(folder_name) + '_' + category_name.split(',')[0]
|
folder_name = str(folder_name) + "_" + category_name.split(",")[0]
|
||||||
|
|
||||||
generation_vis_dir = os.path.join(generation_vis_dir, folder_name)
|
generation_vis_dir = os.path.join(generation_vis_dir, folder_name)
|
||||||
|
|
||||||
|
@ -140,10 +149,10 @@ def main():
|
||||||
|
|
||||||
# Timing dict
|
# Timing dict
|
||||||
time_dict = {
|
time_dict = {
|
||||||
'idx': idx,
|
"idx": idx,
|
||||||
'class id': category_id,
|
"class id": category_id,
|
||||||
'class name': category_name,
|
"class name": category_name,
|
||||||
'modelname':modelname,
|
"modelname": modelname,
|
||||||
}
|
}
|
||||||
time_dicts.append(time_dict)
|
time_dicts.append(time_dict)
|
||||||
|
|
||||||
|
@ -158,60 +167,56 @@ def main():
|
||||||
time_dict.update(stats_dict)
|
time_dict.update(stats_dict)
|
||||||
|
|
||||||
# Write output
|
# Write output
|
||||||
mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname)
|
mesh_out_file = os.path.join(mesh_dir, "%s.off" % modelname)
|
||||||
export_mesh(mesh_out_file, scale2onet(v), f)
|
export_mesh(mesh_out_file, scale2onet(v), f)
|
||||||
out_file_dict['mesh'] = mesh_out_file
|
out_file_dict["mesh"] = mesh_out_file
|
||||||
|
|
||||||
if generate_pointcloud:
|
if generate_pointcloud:
|
||||||
pointcloud_out_file = os.path.join(
|
pointcloud_out_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
|
||||||
pointcloud_dir, '%s.ply' % modelname)
|
|
||||||
export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
|
export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
|
||||||
out_file_dict['pointcloud'] = pointcloud_out_file
|
out_file_dict["pointcloud"] = pointcloud_out_file
|
||||||
|
|
||||||
if cfg['generation']['copy_input']:
|
if cfg["generation"]["copy_input"]:
|
||||||
inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
|
inputs_path = os.path.join(in_dir, "%s.ply" % modelname)
|
||||||
p = data.get('inputs').to(device)
|
p = data.get("inputs").to(device)
|
||||||
export_pointcloud(inputs_path, scale2onet(p))
|
export_pointcloud(inputs_path, scale2onet(p))
|
||||||
out_file_dict['in'] = inputs_path
|
out_file_dict["in"] = inputs_path
|
||||||
|
|
||||||
# Copy to visualization directory for first vis_n_output samples
|
# Copy to visualization directory for first vis_n_output samples
|
||||||
c_it = model_counter[category_id]
|
c_it = model_counter[category_id]
|
||||||
if c_it < vis_n_outputs:
|
if c_it < vis_n_outputs:
|
||||||
# Save output files
|
# Save output files
|
||||||
img_name = '%02d.off' % c_it
|
"%02d.off" % c_it
|
||||||
for k, filepath in out_file_dict.items():
|
for k, filepath in out_file_dict.items():
|
||||||
ext = os.path.splitext(filepath)[1]
|
ext = os.path.splitext(filepath)[1]
|
||||||
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
|
out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, k, ext))
|
||||||
% (c_it, k, ext))
|
|
||||||
shutil.copyfile(filepath, out_file)
|
shutil.copyfile(filepath, out_file)
|
||||||
|
|
||||||
# Also generate oracle meshes
|
# Also generate oracle meshes
|
||||||
if cfg['generation']['exp_oracle']:
|
if cfg["generation"]["exp_oracle"]:
|
||||||
points_gt = data.get('gt_points').to(device)
|
points_gt = data.get("gt_points").to(device)
|
||||||
normals_gt = data.get('gt_points.normals').to(device)
|
normals_gt = data.get("gt_points.normals").to(device)
|
||||||
psr_gt = dpsr(points_gt, normals_gt)
|
psr_gt = dpsr(points_gt, normals_gt)
|
||||||
v, f, _ = mc_from_psr(psr_gt,
|
v, f, _ = mc_from_psr(psr_gt, zero_level=cfg["data"]["zero_level"])
|
||||||
zero_level=cfg['data']['zero_level'])
|
out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, "mesh_oracle", ".off"))
|
||||||
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
|
|
||||||
% (c_it, 'mesh_oracle', '.off'))
|
|
||||||
export_mesh(out_file, scale2onet(v), f)
|
export_mesh(out_file, scale2onet(v), f)
|
||||||
|
|
||||||
model_counter[category_id] += 1
|
model_counter[category_id] += 1
|
||||||
|
|
||||||
|
|
||||||
# Create pandas dataframe and save
|
# Create pandas dataframe and save
|
||||||
time_df = pd.DataFrame(time_dicts)
|
time_df = pd.DataFrame(time_dicts)
|
||||||
time_df.set_index(['idx'], inplace=True)
|
time_df.set_index(["idx"], inplace=True)
|
||||||
time_df.to_pickle(out_time_file)
|
time_df.to_pickle(out_time_file)
|
||||||
|
|
||||||
# Create pickle files with main statistics
|
# Create pickle files with main statistics
|
||||||
time_df_class = time_df.groupby(by=['class name']).mean()
|
time_df_class = time_df.groupby(by=["class name"]).mean()
|
||||||
time_df_class.loc['mean'] = time_df_class.mean()
|
time_df_class.loc["mean"] = time_df_class.mean()
|
||||||
time_df_class.to_pickle(out_time_file_class)
|
time_df_class.to_pickle(out_time_file_class)
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
print('Timings [s]:')
|
print("Timings [s]:")
|
||||||
print(time_df_class)
|
print(time_df_class)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
300
optim.py
300
optim.py
|
@ -1,79 +1,86 @@
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import open3d as o3d
|
||||||
import torch
|
import torch
|
||||||
import trimesh
|
import trimesh
|
||||||
import shutil, argparse, time, os, glob
|
from plyfile import PlyData
|
||||||
|
from pytorch3d.io import load_objs_as_meshes
|
||||||
import numpy as np; np.set_printoptions(precision=4)
|
from pytorch3d.ops import sample_points_from_meshes
|
||||||
import open3d as o3d
|
from pytorch3d.structures import Meshes
|
||||||
|
from skimage import measure
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from torchvision.utils import save_image
|
|
||||||
from torchvision.io import write_video
|
|
||||||
|
|
||||||
from src.optimization import Trainer
|
from src.optimization import Trainer
|
||||||
from src.utils import load_config, update_config, initialize_logger, \
|
from src.utils import (
|
||||||
get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\
|
AverageMeter,
|
||||||
update_optimizer, export_pointcloud
|
adjust_learning_rate,
|
||||||
from skimage import measure
|
export_pointcloud,
|
||||||
from plyfile import PlyData
|
get_learning_rate_schedules,
|
||||||
from pytorch3d.ops import sample_points_from_meshes
|
initialize_logger,
|
||||||
from pytorch3d.io import load_objs_as_meshes
|
load_config,
|
||||||
from pytorch3d.structures import Meshes
|
update_config,
|
||||||
|
update_optimizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
np.set_printoptions(precision=4)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||||
parser.add_argument('config', type=str, help='Path to config file.')
|
parser.add_argument("config", type=str, help="Path to config file.")
|
||||||
parser.add_argument('--no_cuda', action='store_true', default=False,
|
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
|
||||||
help='disables CUDA training')
|
parser.add_argument("--seed", type=int, default=1457, metavar="S", help="random seed")
|
||||||
parser.add_argument('--seed', type=int, default=1457, metavar='S',
|
|
||||||
help='random seed')
|
|
||||||
|
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
cfg = load_config(args.config, 'configs/default.yaml')
|
cfg = load_config(args.config, "configs/default.yaml")
|
||||||
cfg = update_config(cfg, unknown)
|
cfg = update_config(cfg, unknown)
|
||||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
device = torch.device("cuda" if use_cuda else "cpu")
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
data_type = cfg['data']['data_type']
|
data_type = cfg["data"]["data_type"]
|
||||||
data_class = cfg['data']['class']
|
cfg["data"]["class"]
|
||||||
|
|
||||||
print(cfg['train']['out_dir'])
|
print(cfg["train"]["out_dir"])
|
||||||
|
|
||||||
# PYTORCH VERSION > 1.0.0
|
# PYTORCH VERSION > 1.0.0
|
||||||
assert(float(torch.__version__.split('.')[-3]) > 0)
|
assert float(torch.__version__.split(".")[-3]) > 0
|
||||||
|
|
||||||
# boiler-plate
|
# boiler-plate
|
||||||
if cfg['train']['timestamp']:
|
if cfg["train"]["timestamp"]:
|
||||||
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
|
cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||||
logger = initialize_logger(cfg)
|
logger = initialize_logger(cfg)
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
shutil.copyfile(args.config,
|
shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))
|
||||||
os.path.join(cfg['train']['out_dir'], 'config.yaml'))
|
|
||||||
|
|
||||||
# tensorboardX writer
|
# tensorboardX writer
|
||||||
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
|
tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
|
||||||
if not os.path.exists(tblogdir):
|
if not os.path.exists(tblogdir):
|
||||||
os.makedirs(tblogdir)
|
os.makedirs(tblogdir)
|
||||||
writer = SummaryWriter(log_dir=tblogdir)
|
SummaryWriter(log_dir=tblogdir)
|
||||||
|
|
||||||
# initialize o3d visualizer
|
# initialize o3d visualizer
|
||||||
vis = None
|
vis = None
|
||||||
if cfg['train']['o3d_show']:
|
if cfg["train"]["o3d_show"]:
|
||||||
vis = o3d.visualization.Visualizer()
|
vis = o3d.visualization.Visualizer()
|
||||||
vis.create_window(width=cfg['train']['o3d_window_size'],
|
vis.create_window(width=cfg["train"]["o3d_window_size"], height=cfg["train"]["o3d_window_size"])
|
||||||
height=cfg['train']['o3d_window_size'])
|
|
||||||
|
|
||||||
# initialize dataset
|
# initialize dataset
|
||||||
if data_type == 'point':
|
if data_type == "point":
|
||||||
if cfg['data']['object_id'] != -1:
|
if cfg["data"]["object_id"] != -1:
|
||||||
data_paths = sorted(glob.glob(cfg['data']['data_path']))
|
data_paths = sorted(glob.glob(cfg["data"]["data_path"]))
|
||||||
data_path = data_paths[cfg['data']['object_id']]
|
data_path = data_paths[cfg["data"]["object_id"]]
|
||||||
print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths)))
|
print("Loaded %d/%d object" % (cfg["data"]["object_id"] + 1, len(data_paths)))
|
||||||
else:
|
else:
|
||||||
data_path = cfg['data']['data_path']
|
data_path = cfg["data"]["data_path"]
|
||||||
print('Data loaded')
|
print("Data loaded")
|
||||||
ext = data_path.split('.')[-1]
|
ext = data_path.split(".")[-1]
|
||||||
if ext == 'obj': # have GT mesh
|
if ext == "obj": # have GT mesh
|
||||||
mesh = load_objs_as_meshes([data_path], device=device)
|
mesh = load_objs_as_meshes([data_path], device=device)
|
||||||
# scale the mesh into unit cube
|
# scale the mesh into unit cube
|
||||||
verts = mesh.verts_packed()
|
verts = mesh.verts_packed()
|
||||||
|
@ -81,20 +88,15 @@ def main():
|
||||||
center = verts.mean(0)
|
center = verts.mean(0)
|
||||||
mesh.offset_verts_(-center.expand(N, 3))
|
mesh.offset_verts_(-center.expand(N, 3))
|
||||||
scale = max((verts - center).abs().max(0)[0])
|
scale = max((verts - center).abs().max(0)[0])
|
||||||
mesh.scale_verts_((1.0 / float(scale)))
|
mesh.scale_verts_(1.0 / float(scale))
|
||||||
# important for our DPSR to have the range in [0, 1), not reaching 1
|
# important for our DPSR to have the range in [0, 1), not reaching 1
|
||||||
mesh.scale_verts_(0.9)
|
mesh.scale_verts_(0.9)
|
||||||
|
|
||||||
target_pts, target_normals = sample_points_from_meshes(mesh,
|
target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True)
|
||||||
num_samples=200000, return_normals=True)
|
elif ext == "ply": # only have the point cloud
|
||||||
elif ext == 'ply': # only have the point cloud
|
|
||||||
plydata = PlyData.read(data_path)
|
plydata = PlyData.read(data_path)
|
||||||
vertices = np.stack([plydata['vertex']['x'],
|
vertices = np.stack([plydata["vertex"]["x"], plydata["vertex"]["y"], plydata["vertex"]["z"]], axis=1)
|
||||||
plydata['vertex']['y'],
|
normals = np.stack([plydata["vertex"]["nx"], plydata["vertex"]["ny"], plydata["vertex"]["nz"]], axis=1)
|
||||||
plydata['vertex']['z']], axis=1)
|
|
||||||
normals = np.stack([plydata['vertex']['nx'],
|
|
||||||
plydata['vertex']['ny'],
|
|
||||||
plydata['vertex']['nz']], axis=1)
|
|
||||||
N = vertices.shape[0]
|
N = vertices.shape[0]
|
||||||
center = vertices.mean(0)
|
center = vertices.mean(0)
|
||||||
scale = np.max(np.max(np.abs(vertices - center), axis=0))
|
scale = np.max(np.max(np.abs(vertices - center), axis=0))
|
||||||
|
@ -111,205 +113,205 @@ def main():
|
||||||
if not torch.is_tensor(scale):
|
if not torch.is_tensor(scale):
|
||||||
scale = torch.from_numpy(np.array([scale]))
|
scale = torch.from_numpy(np.array([scale]))
|
||||||
|
|
||||||
data = {'target_points': target_pts,
|
data = {
|
||||||
'target_normals': target_normals, # normals are never used
|
"target_points": target_pts,
|
||||||
'gt_mesh': mesh}
|
"target_normals": target_normals, # normals are never used
|
||||||
|
"gt_mesh": mesh,
|
||||||
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# save the input point cloud
|
# save the input point cloud
|
||||||
if 'target_points' in data.keys():
|
if "target_points" in data.keys():
|
||||||
outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply')
|
outdir_pcl = os.path.join(cfg["train"]["out_dir"], "target_pcl.ply")
|
||||||
if 'target_normals' in data.keys():
|
if "target_normals" in data.keys():
|
||||||
export_pointcloud(outdir_pcl, data['target_points'], data['target_normals'])
|
export_pointcloud(outdir_pcl, data["target_points"], data["target_normals"])
|
||||||
else:
|
else:
|
||||||
export_pointcloud(outdir_pcl, data['target_points'])
|
export_pointcloud(outdir_pcl, data["target_points"])
|
||||||
|
|
||||||
# save oracle PSR mesh (mesh from our PSR using GT point+normals)
|
# save oracle PSR mesh (mesh from our PSR using GT point+normals)
|
||||||
if data.get('gt_mesh') is not None:
|
if data.get("gt_mesh") is not None:
|
||||||
gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0)
|
gt_verts, gt_faces = data["gt_mesh"].get_mesh_verts_faces(0)
|
||||||
pts_gt, norms_gt = sample_points_from_meshes(data['gt_mesh'],
|
pts_gt, norms_gt = sample_points_from_meshes(data["gt_mesh"], num_samples=500000, return_normals=True)
|
||||||
num_samples=500000, return_normals=True)
|
|
||||||
pts_gt = (pts_gt + 1) / 2
|
pts_gt = (pts_gt + 1) / 2
|
||||||
from src.dpsr import DPSR
|
from src.dpsr import DPSR
|
||||||
dpsr_tmp = DPSR(res=(cfg['model']['grid_res'],
|
|
||||||
cfg['model']['grid_res'],
|
dpsr_tmp = DPSR(
|
||||||
cfg['model']['grid_res']),
|
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
|
||||||
sig=cfg['model']['psr_sigma']).to(device)
|
sig=cfg["model"]["psr_sigma"],
|
||||||
|
).to(device)
|
||||||
target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
|
target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
|
||||||
target = torch.tanh(target)
|
target = torch.tanh(target)
|
||||||
s = target.shape[-1] # size of psr_grid
|
s = target.shape[-1] # size of psr_grid
|
||||||
psr_grid_numpy = target.squeeze().detach().cpu().numpy()
|
psr_grid_numpy = target.squeeze().detach().cpu().numpy()
|
||||||
verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
|
verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
|
||||||
verts = verts / s * 2. - 1 # [-1, 1]
|
verts = verts / s * 2.0 - 1 # [-1, 1]
|
||||||
mesh = o3d.geometry.TriangleMesh()
|
mesh = o3d.geometry.TriangleMesh()
|
||||||
mesh.vertices = o3d.utility.Vector3dVector(verts)
|
mesh.vertices = o3d.utility.Vector3dVector(verts)
|
||||||
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
||||||
outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply')
|
outdir_mesh = os.path.join(cfg["train"]["out_dir"], "oracle_mesh.ply")
|
||||||
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
|
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
|
||||||
|
|
||||||
# initialize the source point cloud given an input mesh
|
# initialize the source point cloud given an input mesh
|
||||||
if 'input_mesh' in cfg['train'].keys() and \
|
if "input_mesh" in cfg["train"].keys() and os.path.isfile(cfg["train"]["input_mesh"]):
|
||||||
os.path.isfile(cfg['train']['input_mesh']):
|
if cfg["train"]["input_mesh"].split("/")[-2] == "mesh":
|
||||||
if cfg['train']['input_mesh'].split('/')[-2] == 'mesh':
|
mesh_tmp = trimesh.load_mesh(cfg["train"]["input_mesh"])
|
||||||
mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh'])
|
|
||||||
verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
|
verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
|
||||||
faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
|
faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
|
||||||
mesh = Meshes(verts=verts, faces=faces)
|
mesh = Meshes(verts=verts, faces=faces)
|
||||||
points, normals = sample_points_from_meshes(mesh,
|
points, normals = sample_points_from_meshes(
|
||||||
num_samples=cfg['data']['num_points'], return_normals=True)
|
mesh, num_samples=cfg["data"]["num_points"], return_normals=True,
|
||||||
|
)
|
||||||
# mesh is saved in the original scale of the gt
|
# mesh is saved in the original scale of the gt
|
||||||
points -= center.float().to(device)
|
points -= center.float().to(device)
|
||||||
points /= scale.float().to(device)
|
points /= scale.float().to(device)
|
||||||
points *= 0.9
|
points *= 0.9
|
||||||
# make sure the points are within the range of [0, 1)
|
# make sure the points are within the range of [0, 1)
|
||||||
points = points / 2. + 0.5
|
points = points / 2.0 + 0.5
|
||||||
else:
|
else:
|
||||||
# directly initialize from a point cloud
|
# directly initialize from a point cloud
|
||||||
pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh'])
|
pcd = o3d.io.read_point_cloud(cfg["train"]["input_mesh"])
|
||||||
points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
|
points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
|
||||||
normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
|
normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
|
||||||
points -= center.float().to(device)
|
points -= center.float().to(device)
|
||||||
points /= scale.float().to(device)
|
points /= scale.float().to(device)
|
||||||
points *= 0.9
|
points *= 0.9
|
||||||
points = points / 2. + 0.5
|
points = points / 2.0 + 0.5
|
||||||
else: #! initialize our source point cloud from a sphere
|
else: #! initialize our source point cloud from a sphere
|
||||||
sphere_radius = cfg['model']['sphere_radius']
|
sphere_radius = cfg["model"]["sphere_radius"]
|
||||||
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius,
|
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256, 256])
|
||||||
count=[256,256])
|
points, idx = sphere_mesh.sample(cfg["data"]["num_points"], return_index=True)
|
||||||
points, idx = sphere_mesh.sample(cfg['data']['num_points'],
|
|
||||||
return_index=True)
|
|
||||||
points += 0.5 # make sure the points are within the range of [0, 1)
|
points += 0.5 # make sure the points are within the range of [0, 1)
|
||||||
normals = sphere_mesh.face_normals[idx]
|
normals = sphere_mesh.face_normals[idx]
|
||||||
points = torch.from_numpy(points).unsqueeze(0).to(device)
|
points = torch.from_numpy(points).unsqueeze(0).to(device)
|
||||||
normals = torch.from_numpy(normals).unsqueeze(0).to(device)
|
normals = torch.from_numpy(normals).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
points = torch.log(points / (1 - points)) # inverse sigmoid
|
||||||
points = torch.log(points/(1-points)) # inverse sigmoid
|
|
||||||
inputs = torch.cat([points, normals], axis=-1).float()
|
inputs = torch.cat([points, normals], axis=-1).float()
|
||||||
inputs.requires_grad = True
|
inputs.requires_grad = True
|
||||||
|
|
||||||
model = None # no network
|
model = None # no network
|
||||||
|
|
||||||
# initialize optimizer
|
# initialize optimizer
|
||||||
cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl']
|
cfg["train"]["schedule"]["pcl"]["initial"] = cfg["train"]["lr_pcl"]
|
||||||
print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial'])
|
print("Initial learning rate:", cfg["train"]["schedule"]["pcl"]["initial"])
|
||||||
if 'schedule' in cfg['train']:
|
if "schedule" in cfg["train"]:
|
||||||
lr_schedules = get_learning_rate_schedules(cfg['train']['schedule'])
|
lr_schedules = get_learning_rate_schedules(cfg["train"]["schedule"])
|
||||||
else:
|
else:
|
||||||
lr_schedules = None
|
lr_schedules = None
|
||||||
|
|
||||||
optimizer = update_optimizer(inputs, cfg,
|
optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules)
|
||||||
epoch=0, model=model, schedule=lr_schedules)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# load model
|
# load model
|
||||||
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
|
||||||
if ('pcl' in state_dict.keys()) & (state_dict['pcl'] is not None):
|
if ("pcl" in state_dict.keys()) & (state_dict["pcl"] is not None):
|
||||||
inputs = state_dict['pcl'].to(device)
|
inputs = state_dict["pcl"].to(device)
|
||||||
inputs.requires_grad = True
|
inputs.requires_grad = True
|
||||||
|
|
||||||
optimizer = update_optimizer(inputs, cfg,
|
optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get("epoch"), schedule=lr_schedules)
|
||||||
epoch=state_dict.get('epoch'), schedule=lr_schedules)
|
|
||||||
|
|
||||||
out = "Load model from epoch %d" % state_dict.get('epoch', 0)
|
out = "Load model from epoch %d" % state_dict.get("epoch", 0)
|
||||||
print(out)
|
print(out)
|
||||||
logger.info(out)
|
logger.info(out)
|
||||||
except:
|
except:
|
||||||
state_dict = dict()
|
state_dict = dict()
|
||||||
|
|
||||||
start_epoch = state_dict.get('epoch', -1)
|
start_epoch = state_dict.get("epoch", -1)
|
||||||
|
|
||||||
trainer = Trainer(cfg, optimizer, device=device)
|
trainer = Trainer(cfg, optimizer, device=device)
|
||||||
runtime = {}
|
runtime = {}
|
||||||
runtime['all'] = AverageMeter()
|
runtime["all"] = AverageMeter()
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
|
for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
|
||||||
|
|
||||||
# schedule the learning rate
|
# schedule the learning rate
|
||||||
if (epoch>0) & (lr_schedules is not None):
|
if (epoch > 0) & (lr_schedules is not None):
|
||||||
if (epoch % lr_schedules[0].interval == 0):
|
if epoch % lr_schedules[0].interval == 0:
|
||||||
adjust_learning_rate(lr_schedules, optimizer, epoch)
|
adjust_learning_rate(lr_schedules, optimizer, epoch)
|
||||||
if len(lr_schedules) >1:
|
if len(lr_schedules) > 1:
|
||||||
print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch,
|
print(
|
||||||
lr_schedules[0].get_learning_rate(epoch),
|
"[epoch {}] net_lr: {}, pcl_lr: {}".format(
|
||||||
lr_schedules[1].get_learning_rate(epoch)))
|
epoch, lr_schedules[0].get_learning_rate(epoch), lr_schedules[1].get_learning_rate(epoch),
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print('[epoch {}] adjust pcl_lr to: {}'.format(epoch,
|
print(f"[epoch {epoch}] adjust pcl_lr to: {lr_schedules[0].get_learning_rate(epoch)}")
|
||||||
lr_schedules[0].get_learning_rate(epoch)))
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
loss, loss_each = trainer.train_step(data, inputs, model, epoch)
|
loss, loss_each = trainer.train_step(data, inputs, model, epoch)
|
||||||
runtime['all'].update(time.time() - start)
|
runtime["all"].update(time.time() - start)
|
||||||
|
|
||||||
if epoch % cfg['train']['print_every'] == 0:
|
if epoch % cfg["train"]["print_every"] == 0:
|
||||||
log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss)
|
log_text = ("[Epoch %02d] loss=%.5f") % (epoch, loss)
|
||||||
if loss_each is not None:
|
if loss_each is not None:
|
||||||
for k, l in loss_each.items():
|
for k, l in loss_each.items():
|
||||||
if l.item() != 0.:
|
if l.item() != 0.0:
|
||||||
log_text += (' loss_%s=%.5f') % (k, l.item())
|
log_text += f" loss_{k}={l.item():.5f}"
|
||||||
|
|
||||||
log_text += (' time=%.3f / %.3f') % (runtime['all'].val,
|
log_text += (" time={:.3f} / {:.3f}").format(runtime["all"].val, runtime["all"].sum)
|
||||||
runtime['all'].sum)
|
|
||||||
logger.info(log_text)
|
logger.info(log_text)
|
||||||
print(log_text)
|
print(log_text)
|
||||||
|
|
||||||
# visualize point clouds and meshes
|
# visualize point clouds and meshes
|
||||||
if (epoch % cfg['train']['visualize_every'] == 0) & (vis is not None):
|
if (epoch % cfg["train"]["visualize_every"] == 0) & (vis is not None):
|
||||||
trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)
|
trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)
|
||||||
|
|
||||||
# save outputs
|
# save outputs
|
||||||
if epoch % cfg['train']['save_every'] == 0:
|
if epoch % cfg["train"]["save_every"] == 0:
|
||||||
trainer.save_mesh_pointclouds(inputs, epoch,
|
trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy() * (1 / 0.9))
|
||||||
center.cpu().numpy(),
|
|
||||||
scale.cpu().numpy()*(1/0.9))
|
|
||||||
|
|
||||||
# save checkpoints
|
# save checkpoints
|
||||||
if (epoch > 0) & (epoch % cfg['train']['checkpoint_every'] == 0):
|
if (epoch > 0) & (epoch % cfg["train"]["checkpoint_every"] == 0):
|
||||||
state = {'epoch': epoch}
|
state = {"epoch": epoch}
|
||||||
pcl = None
|
|
||||||
if isinstance(inputs, torch.Tensor):
|
if isinstance(inputs, torch.Tensor):
|
||||||
state['pcl'] = inputs.detach().cpu()
|
state["pcl"] = inputs.detach().cpu()
|
||||||
|
|
||||||
torch.save(state, os.path.join(cfg['train']['dir_model'],
|
torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % epoch + ".pt"))
|
||||||
'%04d' % epoch + '.pt'))
|
|
||||||
print("Save new model at epoch %d" % epoch)
|
print("Save new model at epoch %d" % epoch)
|
||||||
logger.info("Save new model at epoch %d" % epoch)
|
logger.info("Save new model at epoch %d" % epoch)
|
||||||
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))
|
||||||
|
|
||||||
# resample and gradually add new points to the source pcl
|
# resample and gradually add new points to the source pcl
|
||||||
if (epoch > 0) & \
|
if (
|
||||||
(cfg['train']['resample_every']!=0) & \
|
(epoch > 0)
|
||||||
(epoch % cfg['train']['resample_every'] == 0) & \
|
& (cfg["train"]["resample_every"] != 0)
|
||||||
(epoch < cfg['train']['total_epochs']):
|
& (epoch % cfg["train"]["resample_every"] == 0)
|
||||||
|
& (epoch < cfg["train"]["total_epochs"])
|
||||||
|
):
|
||||||
inputs = trainer.point_resampling(inputs)
|
inputs = trainer.point_resampling(inputs)
|
||||||
optimizer = update_optimizer(inputs, cfg,
|
optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules)
|
||||||
epoch=epoch, model=model, schedule=lr_schedules)
|
|
||||||
trainer = Trainer(cfg, optimizer, device=device)
|
trainer = Trainer(cfg, optimizer, device=device)
|
||||||
|
|
||||||
# visualize the Open3D outputs
|
# visualize the Open3D outputs
|
||||||
if cfg['train']['o3d_show']:
|
if cfg["train"]["o3d_show"]:
|
||||||
out_video_dir = os.path.join(cfg['train']['out_dir'],
|
out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video.mp4")
|
||||||
'vis/o3d/video.mp4')
|
|
||||||
if os.path.isfile(out_video_dir):
|
if os.path.isfile(out_video_dir):
|
||||||
os.system('rm {}'.format(out_video_dir))
|
os.system(f"rm {out_video_dir}")
|
||||||
os.system('ffmpeg -framerate 30 \
|
os.system(
|
||||||
|
"ffmpeg -framerate 30 \
|
||||||
-start_number 0 \
|
-start_number 0 \
|
||||||
-i {}/vis/o3d/%04d.jpg \
|
-i {}/vis/o3d/%04d.jpg \
|
||||||
-pix_fmt yuv420p \
|
-pix_fmt yuv420p \
|
||||||
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
|
-crf 17 {}".format(
|
||||||
out_video_dir = os.path.join(cfg['train']['out_dir'],
|
cfg["train"]["out_dir"], out_video_dir,
|
||||||
'vis/o3d/video_pcd.mp4')
|
),
|
||||||
|
)
|
||||||
|
out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video_pcd.mp4")
|
||||||
if os.path.isfile(out_video_dir):
|
if os.path.isfile(out_video_dir):
|
||||||
os.system('rm {}'.format(out_video_dir))
|
os.system(f"rm {out_video_dir}")
|
||||||
os.system('ffmpeg -framerate 30 \
|
os.system(
|
||||||
|
"ffmpeg -framerate 30 \
|
||||||
-start_number 0 \
|
-start_number 0 \
|
||||||
-i {}/vis/o3d/%04d_pcd.jpg \
|
-i {}/vis/o3d/%04d_pcd.jpg \
|
||||||
-pix_fmt yuv420p \
|
-pix_fmt yuv420p \
|
||||||
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
|
-crf 17 {}".format(
|
||||||
print('Video saved.')
|
cfg["train"]["out_dir"], out_video_dir,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
print("Video saved.")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1,69 +1,80 @@
|
||||||
import sys, os
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
from src.utils import load_config
|
from src.utils import load_config
|
||||||
import subprocess
|
|
||||||
os.environ['MKL_THREADING_LAYER'] = 'GNU'
|
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
parser.add_argument("config", type=str, help="Path to config file.")
|
||||||
parser.add_argument('config', type=str, help='Path to config file.')
|
parser.add_argument("--start_res", type=int, default=-1, help="Resolution to start with.")
|
||||||
parser.add_argument('--start_res', type=int, default=-1, help='Resolution to start with.')
|
parser.add_argument("--object_id", type=int, default=-1, help="Object index.")
|
||||||
parser.add_argument('--object_id', type=int, default=-1, help='Object index.')
|
|
||||||
|
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
cfg = load_config(args.config, 'configs/default.yaml')
|
cfg = load_config(args.config, "configs/default.yaml")
|
||||||
|
|
||||||
resolutions=[32, 64, 128, 256]
|
resolutions = [32, 64, 128, 256]
|
||||||
iterations=[1000, 1000, 1000, 200]
|
iterations = [1000, 1000, 1000, 200]
|
||||||
lrs=[2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)] # reduce lr
|
lrs = [2e-3, 2e-3 * 0.7, 2e-3 * (0.7**2), 2e-3 * (0.7**3)] # reduce lr
|
||||||
for idx,(res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
|
for idx, (res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
|
||||||
|
if res < args.start_res:
|
||||||
if res<args.start_res:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if res>cfg['model']['grid_res']:
|
if res > cfg["model"]["grid_res"]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
psr_sigma= 2 if res<=128 else 3
|
psr_sigma = 2 if res <= 128 else 3
|
||||||
|
|
||||||
if res > 128:
|
if res > 128:
|
||||||
psr_sigma = 5 if 'thingi_noisy' in args.config else 3
|
psr_sigma = 5 if "thingi_noisy" in args.config else 3
|
||||||
|
|
||||||
if args.object_id != -1:
|
if args.object_id != -1:
|
||||||
out_dir = os.path.join(cfg['train']['out_dir'], 'object_%02d'%args.object_id, 'res_%d'%res)
|
out_dir = os.path.join(cfg["train"]["out_dir"], "object_%02d" % args.object_id, "res_%d" % res)
|
||||||
else:
|
else:
|
||||||
out_dir = os.path.join(cfg['train']['out_dir'], 'res_%d'%res)
|
out_dir = os.path.join(cfg["train"]["out_dir"], "res_%d" % res)
|
||||||
|
|
||||||
# sample from mesh when resampling is enabled, otherwise reuse the pointcloud
|
# sample from mesh when resampling is enabled, otherwise reuse the pointcloud
|
||||||
init_shape='mesh' if cfg['train']['resample_every']>0 else 'pointcloud'
|
init_shape = "mesh" if cfg["train"]["resample_every"] > 0 else "pointcloud"
|
||||||
|
|
||||||
|
|
||||||
if args.object_id != -1:
|
if args.object_id != -1:
|
||||||
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
|
input_mesh = (
|
||||||
'object_%02d'%args.object_id, 'res_%d' % (resolutions[idx-1]),
|
"None"
|
||||||
'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
|
if idx == 0
|
||||||
|
else os.path.join(
|
||||||
|
cfg["train"]["out_dir"],
|
||||||
|
"object_%02d" % args.object_id,
|
||||||
|
"res_%d" % (resolutions[idx - 1]),
|
||||||
|
"vis",
|
||||||
|
init_shape,
|
||||||
|
"%04d.ply" % (iterations[idx - 1]),
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
|
input_mesh = (
|
||||||
'res_%d' % (resolutions[idx-1]),
|
"None"
|
||||||
'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
|
if idx == 0
|
||||||
|
else os.path.join(
|
||||||
|
cfg["train"]["out_dir"],
|
||||||
|
"res_%d" % (resolutions[idx - 1]),
|
||||||
|
"vis",
|
||||||
|
init_shape,
|
||||||
|
"%04d.ply" % (iterations[idx - 1]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd = "export MKL_SERVICE_FORCE_INTEL=1 && "
|
||||||
cmd = 'export MKL_SERVICE_FORCE_INTEL=1 && '
|
cmd += (
|
||||||
cmd += "python optim.py %s --model:grid_res %d --model:psr_sigma %d \
|
"python optim.py %s --model:grid_res %d --model:psr_sigma %d \
|
||||||
--train:input_mesh %s --train:total_epochs %d \
|
--train:input_mesh %s --train:total_epochs %d \
|
||||||
--train:out_dir %s --train:lr_pcl %f \
|
--train:out_dir %s --train:lr_pcl %f \
|
||||||
--data:object_id %d" % (
|
--data:object_id %d"
|
||||||
args.config,
|
% (args.config, res, psr_sigma, input_mesh, iteration, out_dir, lr, args.object_id)
|
||||||
res,
|
)
|
||||||
psr_sigma,
|
|
||||||
input_mesh,
|
|
||||||
iteration,
|
|
||||||
out_dir,
|
|
||||||
lr,
|
|
||||||
args.object_id)
|
|
||||||
print(cmd)
|
print(cmd)
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
|
|
||||||
if __name__=="__main__":
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
55
pyproject.toml
Normal file
55
pyproject.toml
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 120
|
||||||
|
ignore-init-module-imports = true
|
||||||
|
ignore = [
|
||||||
|
"G004", # Logging statement uses f-string
|
||||||
|
"EM102", # Exception must not use an f-string literal, assign to variable first
|
||||||
|
]
|
||||||
|
select = [
|
||||||
|
"A", # flake8-builtins
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"C90", # mccabe
|
||||||
|
"COM", # flake8-commas
|
||||||
|
"D", # pydocstyle
|
||||||
|
"EM", # flake8-errmsg
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"F", # Pyflakes
|
||||||
|
"G", # flake8-logging-format
|
||||||
|
"I", # isort
|
||||||
|
"N", # pep8-naming
|
||||||
|
"PIE", # flake8-pie
|
||||||
|
"PTH", # flake8-use-pathlib
|
||||||
|
"RET", # flake8-return
|
||||||
|
"RUF", # ruff
|
||||||
|
"S", # flake8-bandit
|
||||||
|
"TCH", # flake8-type-checking
|
||||||
|
"TID", # flake8-tidy-imports
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.pydocstyle]
|
||||||
|
convention = "google"
|
||||||
|
|
||||||
|
[tool.ruff.per-file-ignores]
|
||||||
|
"__init__.py" = ["F401"]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
exclude = '''
|
||||||
|
/(
|
||||||
|
\.git
|
||||||
|
\.venv
|
||||||
|
)/
|
||||||
|
'''
|
||||||
|
include = '\.pyi?$'
|
||||||
|
line-length = 120
|
||||||
|
target-version = ["py310"]
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
multi_line_output = 3
|
||||||
|
profile = "black"
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.10"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
|
@ -1,14 +1,16 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import time
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from src.dpsr import DPSR
|
from src.dpsr import DPSR
|
||||||
|
|
||||||
data_path = 'data/ShapeNet' # path for ShapeNet from ONet
|
data_path = "data/ShapeNet" # path for ShapeNet from ONet
|
||||||
base = 'data' # output base directory
|
base = "data" # output base directory
|
||||||
dataset_name = 'shapenet_psr'
|
dataset_name = "shapenet_psr"
|
||||||
multiprocess = True
|
multiprocess = True
|
||||||
njobs = 8
|
njobs = 8
|
||||||
save_pointcloud = True
|
save_pointcloud = True
|
||||||
|
@ -20,20 +22,20 @@ padding = 1.2
|
||||||
|
|
||||||
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
|
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
|
||||||
|
|
||||||
def process_one(obj):
|
|
||||||
|
|
||||||
obj_name = obj.split('/')[-1]
|
def process_one(obj):
|
||||||
c = obj.split('/')[-2]
|
obj_name = obj.split("/")[-1]
|
||||||
|
c = obj.split("/")[-2]
|
||||||
|
|
||||||
# create new for the current object
|
# create new for the current object
|
||||||
out_path_cur = os.path.join(base, dataset_name, c)
|
out_path_cur = os.path.join(base, dataset_name, c)
|
||||||
out_path_cur_obj = os.path.join(out_path_cur, obj_name)
|
out_path_cur_obj = os.path.join(out_path_cur, obj_name)
|
||||||
os.makedirs(out_path_cur_obj, exist_ok=True)
|
os.makedirs(out_path_cur_obj, exist_ok=True)
|
||||||
|
|
||||||
gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz')
|
gt_path = os.path.join(data_path, c, obj_name, "pointcloud.npz")
|
||||||
data = np.load(gt_path)
|
data = np.load(gt_path)
|
||||||
points = data['points']
|
points = data["points"]
|
||||||
normals = data['normals']
|
normals = data["normals"]
|
||||||
|
|
||||||
# normalize the point to [0, 1)
|
# normalize the point to [0, 1)
|
||||||
points = points / padding + 0.5
|
points = points / padding + 0.5
|
||||||
|
@ -41,31 +43,35 @@ def process_one(obj):
|
||||||
#! p = (p - 0.5) * padding
|
#! p = (p - 0.5) * padding
|
||||||
|
|
||||||
if save_pointcloud:
|
if save_pointcloud:
|
||||||
outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
|
outdir = os.path.join(out_path_cur_obj, "pointcloud.npz")
|
||||||
# np.savez(outdir, points=points, normals=normals)
|
# np.savez(outdir, points=points, normals=normals)
|
||||||
np.savez(outdir, points=data['points'], normals=data['normals'])
|
np.savez(outdir, points=data["points"], normals=data["normals"])
|
||||||
# return
|
# return
|
||||||
|
|
||||||
if save_psr_field:
|
if save_psr_field:
|
||||||
psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None],
|
psr_gt = (
|
||||||
torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16)
|
dpsr(torch.from_numpy(points.astype(np.float32))[None], torch.from_numpy(normals.astype(np.float32))[None])
|
||||||
|
.squeeze()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
.astype(np.float16)
|
||||||
|
)
|
||||||
|
|
||||||
outdir = os.path.join(out_path_cur_obj, 'psr.npz')
|
outdir = os.path.join(out_path_cur_obj, "psr.npz")
|
||||||
np.savez(outdir, psr=psr_gt)
|
np.savez(outdir, psr=psr_gt)
|
||||||
|
|
||||||
|
|
||||||
def main(c):
|
def main(c):
|
||||||
|
print("---------------------------------------")
|
||||||
|
print(f"Processing {c} {split}")
|
||||||
|
print("---------------------------------------")
|
||||||
|
|
||||||
print('---------------------------------------')
|
for split in ["train", "val", "test"]:
|
||||||
print('Processing {} {}'.format(c, split))
|
fname = os.path.join(data_path, c, split + ".lst")
|
||||||
print('---------------------------------------')
|
with open(fname) as f:
|
||||||
|
|
||||||
for split in ['train', 'val', 'test']:
|
|
||||||
fname = os.path.join(data_path, c, split+'.lst')
|
|
||||||
with open(fname, 'r') as f:
|
|
||||||
obj_list = f.read().splitlines()
|
obj_list = f.read().splitlines()
|
||||||
|
|
||||||
obj_list = [c+'/'+s for s in obj_list]
|
obj_list = [c + "/" + s for s in obj_list]
|
||||||
|
|
||||||
if multiprocess:
|
if multiprocess:
|
||||||
# multiprocessing.set_start_method('spawn', force=True)
|
# multiprocessing.set_start_method('spawn', force=True)
|
||||||
|
@ -82,20 +88,29 @@ def main(c):
|
||||||
for obj in tqdm(obj_list):
|
for obj in tqdm(obj_list):
|
||||||
process_one(obj)
|
process_one(obj)
|
||||||
|
|
||||||
print('Done Processing {} {}!'.format(c, split))
|
print(f"Done Processing {c} {split}!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
classes = [
|
||||||
classes = ['02691156', '02828884', '02933112',
|
"02691156",
|
||||||
'02958343', '03211117', '03001627',
|
"02828884",
|
||||||
'03636649', '03691459', '04090263',
|
"02933112",
|
||||||
'04256520', '04379243', '04401088', '04530566']
|
"02958343",
|
||||||
|
"03211117",
|
||||||
|
"03001627",
|
||||||
|
"03636649",
|
||||||
|
"03691459",
|
||||||
|
"04090263",
|
||||||
|
"04256520",
|
||||||
|
"04379243",
|
||||||
|
"04401088",
|
||||||
|
"04530566",
|
||||||
|
]
|
||||||
|
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
for c in classes:
|
for c in classes:
|
||||||
main(c)
|
main(c)
|
||||||
|
|
||||||
t_end = time.time()
|
t_end = time.time()
|
||||||
print('Total processing time: ', t_end - t_start)
|
print("Total processing time: ", t_end - t_start)
|
||||||
|
|
139
src/config.py
139
src/config.py
|
@ -1,146 +1,151 @@
|
||||||
import yaml
|
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
from src import data, generation
|
from src import data, generation
|
||||||
from src.dpsr import DPSR
|
from src.dpsr import DPSR
|
||||||
from ipdb import set_trace as st
|
|
||||||
|
|
||||||
|
|
||||||
# Generator for final mesh extraction
|
# Generator for final mesh extraction
|
||||||
def get_generator(model, cfg, device, **kwargs):
|
def get_generator(model, cfg, device, **kwargs):
|
||||||
''' Returns the generator object.
|
"""Returns the generator object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): Occupancy Network model
|
model (nn.Module): Occupancy Network model
|
||||||
cfg (dict): imported yaml config
|
cfg (dict): imported yaml config
|
||||||
device (device): pytorch device
|
device (device): pytorch device
|
||||||
'''
|
"""
|
||||||
|
if cfg["generation"]["psr_resolution"] == 0:
|
||||||
if cfg['generation']['psr_resolution'] == 0:
|
psr_res = cfg["model"]["grid_res"]
|
||||||
psr_res = cfg['model']['grid_res']
|
psr_sigma = cfg["model"]["psr_sigma"]
|
||||||
psr_sigma = cfg['model']['psr_sigma']
|
|
||||||
else:
|
else:
|
||||||
psr_res = cfg['generation']['psr_resolution']
|
psr_res = cfg["generation"]["psr_resolution"]
|
||||||
psr_sigma = cfg['generation']['psr_sigma']
|
psr_sigma = cfg["generation"]["psr_sigma"]
|
||||||
|
|
||||||
dpsr = DPSR(res=(psr_res, psr_res, psr_res),
|
|
||||||
sig= psr_sigma).to(device)
|
|
||||||
|
|
||||||
|
dpsr = DPSR(res=(psr_res, psr_res, psr_res), sig=psr_sigma).to(device)
|
||||||
|
|
||||||
generator = generation.Generator3D(
|
generator = generation.Generator3D(
|
||||||
model,
|
model,
|
||||||
device=device,
|
device=device,
|
||||||
threshold=cfg['data']['zero_level'],
|
threshold=cfg["data"]["zero_level"],
|
||||||
sample=cfg['generation']['use_sampling'],
|
sample=cfg["generation"]["use_sampling"],
|
||||||
input_type = cfg['data']['input_type'],
|
input_type=cfg["data"]["input_type"],
|
||||||
padding=cfg['data']['padding'],
|
padding=cfg["data"]["padding"],
|
||||||
dpsr=dpsr,
|
dpsr=dpsr,
|
||||||
psr_tanh=cfg['model']['psr_tanh']
|
psr_tanh=cfg["model"]["psr_tanh"],
|
||||||
)
|
)
|
||||||
return generator
|
return generator
|
||||||
|
|
||||||
|
|
||||||
# Datasets
|
# Datasets
|
||||||
def get_dataset(mode, cfg, return_idx=False):
|
def get_dataset(mode, cfg, return_idx=False):
|
||||||
''' Returns the dataset.
|
"""Returns the dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): the model which is used
|
model (nn.Module): the model which is used
|
||||||
cfg (dict): config dictionary
|
cfg (dict): config dictionary
|
||||||
return_idx (bool): whether to include an ID field
|
return_idx (bool): whether to include an ID field
|
||||||
'''
|
"""
|
||||||
dataset_type = cfg['data']['dataset']
|
dataset_type = cfg["data"]["dataset"]
|
||||||
dataset_folder = cfg['data']['path']
|
dataset_folder = cfg["data"]["path"]
|
||||||
categories = cfg['data']['class']
|
categories = cfg["data"]["class"]
|
||||||
|
|
||||||
# Get split
|
# Get split
|
||||||
splits = {
|
splits = {
|
||||||
'train': cfg['data']['train_split'],
|
"train": cfg["data"]["train_split"],
|
||||||
'val': cfg['data']['val_split'],
|
"val": cfg["data"]["val_split"],
|
||||||
'test': cfg['data']['test_split'],
|
"test": cfg["data"]["test_split"],
|
||||||
'vis': cfg['data']['val_split'],
|
"vis": cfg["data"]["val_split"],
|
||||||
}
|
}
|
||||||
|
|
||||||
split = splits[mode]
|
split = splits[mode]
|
||||||
|
|
||||||
# Create dataset
|
# Create dataset
|
||||||
if dataset_type == 'Shapes3D':
|
if dataset_type == "Shapes3D":
|
||||||
fields = get_data_fields(mode, cfg)
|
fields = get_data_fields(mode, cfg)
|
||||||
# Input fields
|
# Input fields
|
||||||
inputs_field = get_inputs_field(mode, cfg)
|
inputs_field = get_inputs_field(mode, cfg)
|
||||||
if inputs_field is not None:
|
if inputs_field is not None:
|
||||||
fields['inputs'] = inputs_field
|
fields["inputs"] = inputs_field
|
||||||
|
|
||||||
if return_idx:
|
if return_idx:
|
||||||
fields['idx'] = data.IndexField()
|
fields["idx"] = data.IndexField()
|
||||||
|
|
||||||
dataset = data.Shapes3dDataset(
|
dataset = data.Shapes3dDataset(
|
||||||
dataset_folder, fields,
|
dataset_folder,
|
||||||
|
fields,
|
||||||
split=split,
|
split=split,
|
||||||
categories=categories,
|
categories=categories,
|
||||||
cfg = cfg
|
cfg=cfg,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])
|
raise ValueError('Invalid dataset "%s"' % cfg["data"]["dataset"])
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def get_inputs_field(mode, cfg):
|
def get_inputs_field(mode, cfg):
|
||||||
''' Returns the inputs fields.
|
"""Returns the inputs fields.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mode (str): the mode which is used
|
mode (str): the mode which is used
|
||||||
cfg (dict): config dictionary
|
cfg (dict): config dictionary
|
||||||
'''
|
"""
|
||||||
input_type = cfg['data']['input_type']
|
input_type = cfg["data"]["input_type"]
|
||||||
|
|
||||||
if input_type is None:
|
if input_type is None:
|
||||||
inputs_field = None
|
inputs_field = None
|
||||||
elif input_type == 'pointcloud':
|
elif input_type == "pointcloud":
|
||||||
noise_level = cfg['data']['pointcloud_noise']
|
noise_level = cfg["data"]["pointcloud_noise"]
|
||||||
if cfg['data']['pointcloud_outlier_ratio']>0:
|
if cfg["data"]["pointcloud_outlier_ratio"] > 0:
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose(
|
||||||
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
|
[
|
||||||
|
data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]),
|
||||||
data.PointcloudNoise(noise_level),
|
data.PointcloudNoise(noise_level),
|
||||||
data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio'])
|
data.PointcloudOutliers(cfg["data"]["pointcloud_outlier_ratio"]),
|
||||||
])
|
],
|
||||||
else:
|
|
||||||
transform = transforms.Compose([
|
|
||||||
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
|
|
||||||
data.PointcloudNoise(noise_level)
|
|
||||||
])
|
|
||||||
|
|
||||||
data_type = cfg['data']['data_type']
|
|
||||||
inputs_field = data.PointCloudField(
|
|
||||||
cfg['data']['pointcloud_file'], data_type, transform,
|
|
||||||
multi_files= cfg['data']['multi_files']
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
transform = transforms.Compose(
|
||||||
'Invalid input type (%s)' % input_type)
|
[
|
||||||
|
data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]),
|
||||||
|
data.PointcloudNoise(noise_level),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
data_type = cfg["data"]["data_type"]
|
||||||
|
inputs_field = data.PointCloudField(
|
||||||
|
cfg["data"]["pointcloud_file"],
|
||||||
|
data_type,
|
||||||
|
transform,
|
||||||
|
multi_files=cfg["data"]["multi_files"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid input type (%s)" % input_type)
|
||||||
return inputs_field
|
return inputs_field
|
||||||
|
|
||||||
|
|
||||||
def get_data_fields(mode, cfg):
|
def get_data_fields(mode, cfg):
|
||||||
''' Returns the data fields.
|
"""Returns the data fields.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mode (str): the mode which is used
|
mode (str): the mode which is used
|
||||||
cfg (dict): imported yaml config
|
cfg (dict): imported yaml config
|
||||||
'''
|
"""
|
||||||
data_type = cfg['data']['data_type']
|
data_type = cfg["data"]["data_type"]
|
||||||
fields = {}
|
fields = {}
|
||||||
|
|
||||||
if (mode in ('val', 'test')):
|
if mode in ("val", "test"):
|
||||||
transform = data.SubsamplePointcloud(100000)
|
transform = data.SubsamplePointcloud(100000)
|
||||||
else:
|
else:
|
||||||
transform = data.SubsamplePointcloud(cfg['data']['num_gt_points'])
|
transform = data.SubsamplePointcloud(cfg["data"]["num_gt_points"])
|
||||||
|
|
||||||
data_name = cfg['data']['pointcloud_file']
|
data_name = cfg["data"]["pointcloud_file"]
|
||||||
fields['gt_points'] = data.PointCloudField(data_name,
|
fields["gt_points"] = data.PointCloudField(
|
||||||
transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files'])
|
data_name, transform=transform, data_type=data_type, multi_files=cfg["data"]["multi_files"],
|
||||||
if data_type == 'psr_full':
|
)
|
||||||
if mode != 'test':
|
if data_type == "psr_full":
|
||||||
fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files'])
|
if mode != "test":
|
||||||
|
fields["gt_psr"] = data.FullPSRField(multi_files=cfg["data"]["multi_files"])
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid data type (%s)' % data_type)
|
raise ValueError("Invalid data type (%s)" % data_type)
|
||||||
|
|
||||||
return fields
|
return fields
|
|
@ -1,14 +1,12 @@
|
||||||
|
|
||||||
from src.data.core import (
|
from src.data.core import (
|
||||||
Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together
|
Shapes3dDataset,
|
||||||
)
|
collate_remove_none,
|
||||||
from src.data.fields import (
|
collate_stack_together,
|
||||||
IndexField, PointCloudField, FullPSRField
|
worker_init_fn,
|
||||||
)
|
|
||||||
from src.data.transforms import (
|
|
||||||
PointcloudNoise, SubsamplePointcloud,
|
|
||||||
PointcloudOutliers,
|
|
||||||
)
|
)
|
||||||
|
from src.data.fields import FullPSRField, IndexField, PointCloudField
|
||||||
|
from src.data.transforms import PointcloudNoise, PointcloudOutliers, SubsamplePointcloud
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Core
|
# Core
|
||||||
Shapes3dDataset,
|
Shapes3dDataset,
|
||||||
|
|
132
src/data/core.py
132
src/data/core.py
|
@ -1,45 +1,41 @@
|
||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
from torch.utils import data
|
import os
|
||||||
from pdb import set_trace as st
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import yaml
|
import yaml
|
||||||
|
from torch.utils import data
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Fields
|
# Fields
|
||||||
class Field(object):
|
class Field:
|
||||||
''' Data fields class.
|
"""Data fields class."""
|
||||||
'''
|
|
||||||
|
|
||||||
def load(self, data_path, idx, category):
|
def load(self, data_path, idx, category):
|
||||||
''' Loads a data point.
|
"""Loads a data point.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_path (str): path to data file
|
data_path (str): path to data file
|
||||||
idx (int): index of data point
|
idx (int): index of data point
|
||||||
category (int): index of category
|
category (int): index of category
|
||||||
'''
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def check_complete(self, files):
|
def check_complete(self, files):
|
||||||
''' Checks if set is complete.
|
"""Checks if set is complete.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
files: files
|
files: files
|
||||||
'''
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class Shapes3dDataset(data.Dataset):
|
class Shapes3dDataset(data.Dataset):
|
||||||
''' 3D Shapes dataset class.
|
"""3D Shapes dataset class."""
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, dataset_folder, fields, split=None,
|
def __init__(self, dataset_folder, fields, split=None, categories=None, no_except=True, transform=None, cfg=None):
|
||||||
categories=None, no_except=True, transform=None, cfg=None):
|
"""Initialization of the the 3D shape dataset.
|
||||||
''' Initialization of the the 3D shape dataset.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_folder (str): dataset folder
|
dataset_folder (str): dataset folder
|
||||||
|
@ -49,7 +45,7 @@ class Shapes3dDataset(data.Dataset):
|
||||||
no_except (bool): no exception
|
no_except (bool): no exception
|
||||||
transform (callable): transformation applied to data points
|
transform (callable): transformation applied to data points
|
||||||
cfg (yaml): config file
|
cfg (yaml): config file
|
||||||
'''
|
"""
|
||||||
# Attributes
|
# Attributes
|
||||||
self.dataset_folder = dataset_folder
|
self.dataset_folder = dataset_folder
|
||||||
self.fields = fields
|
self.fields = fields
|
||||||
|
@ -60,76 +56,69 @@ class Shapes3dDataset(data.Dataset):
|
||||||
# If categories is None, use all subfolders
|
# If categories is None, use all subfolders
|
||||||
if categories is None:
|
if categories is None:
|
||||||
categories = os.listdir(dataset_folder)
|
categories = os.listdir(dataset_folder)
|
||||||
categories = [c for c in categories
|
categories = [c for c in categories if os.path.isdir(os.path.join(dataset_folder, c))]
|
||||||
if os.path.isdir(os.path.join(dataset_folder, c))]
|
|
||||||
|
|
||||||
# Read metadata file
|
# Read metadata file
|
||||||
metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
|
metadata_file = os.path.join(dataset_folder, "metadata.yaml")
|
||||||
|
|
||||||
if os.path.exists(metadata_file):
|
if os.path.exists(metadata_file):
|
||||||
with open(metadata_file, 'r') as f:
|
with open(metadata_file) as f:
|
||||||
self.metadata = yaml.load(f, Loader=yaml.Loader)
|
self.metadata = yaml.load(f, Loader=yaml.Loader)
|
||||||
else:
|
else:
|
||||||
self.metadata = {
|
self.metadata = {c: {"id": c, "name": "n/a"} for c in categories}
|
||||||
c: {'id': c, 'name': 'n/a'} for c in categories
|
|
||||||
}
|
|
||||||
|
|
||||||
# Set index
|
# Set index
|
||||||
for c_idx, c in enumerate(categories):
|
for c_idx, c in enumerate(categories):
|
||||||
self.metadata[c]['idx'] = c_idx
|
self.metadata[c]["idx"] = c_idx
|
||||||
|
|
||||||
# Get all models
|
# Get all models
|
||||||
self.models = []
|
self.models = []
|
||||||
for c_idx, c in enumerate(categories):
|
for c_idx, c in enumerate(categories):
|
||||||
subpath = os.path.join(dataset_folder, c)
|
subpath = os.path.join(dataset_folder, c)
|
||||||
if not os.path.isdir(subpath):
|
if not os.path.isdir(subpath):
|
||||||
logger.warning('Category %s does not exist in dataset.' % c)
|
logger.warning("Category %s does not exist in dataset." % c)
|
||||||
|
|
||||||
if split is None:
|
if split is None:
|
||||||
self.models += [
|
self.models += [
|
||||||
{'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ]
|
{"category": c, "model": m}
|
||||||
|
for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != "")]
|
||||||
]
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
split_file = os.path.join(subpath, split + '.lst')
|
split_file = os.path.join(subpath, split + ".lst")
|
||||||
with open(split_file, 'r') as f:
|
with open(split_file) as f:
|
||||||
models_c = f.read().split('\n')
|
models_c = f.read().split("\n")
|
||||||
|
|
||||||
if '' in models_c:
|
if "" in models_c:
|
||||||
models_c.remove('')
|
models_c.remove("")
|
||||||
|
|
||||||
self.models += [
|
self.models += [{"category": c, "model": m} for m in models_c]
|
||||||
{'category': c, 'model': m}
|
|
||||||
for m in models_c
|
|
||||||
]
|
|
||||||
|
|
||||||
# precompute
|
# precompute
|
||||||
self.split = split
|
self.split = split
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
''' Returns the length of the dataset.
|
"""Returns the length of the dataset."""
|
||||||
'''
|
|
||||||
return len(self.models)
|
return len(self.models)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
''' Returns an item of the dataset.
|
"""Returns an item of the dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
idx (int): ID of data point
|
idx (int): ID of data point
|
||||||
'''
|
"""
|
||||||
|
category = self.models[idx]["category"]
|
||||||
category = self.models[idx]['category']
|
model = self.models[idx]["model"]
|
||||||
model = self.models[idx]['model']
|
c_idx = self.metadata[category]["idx"]
|
||||||
c_idx = self.metadata[category]['idx']
|
|
||||||
|
|
||||||
model_path = os.path.join(self.dataset_folder, category, model)
|
model_path = os.path.join(self.dataset_folder, category, model)
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
info = c_idx
|
info = c_idx
|
||||||
|
|
||||||
if self.cfg['data']['multi_files'] is not None:
|
if self.cfg["data"]["multi_files"] is not None:
|
||||||
idx = np.random.randint(self.cfg['data']['multi_files'])
|
idx = np.random.randint(self.cfg["data"]["multi_files"])
|
||||||
if self.split != 'train':
|
if self.split != "train":
|
||||||
idx = 0
|
idx = 0
|
||||||
|
|
||||||
for field_name, field in self.fields.items():
|
for field_name, field in self.fields.items():
|
||||||
|
@ -137,9 +126,8 @@ class Shapes3dDataset(data.Dataset):
|
||||||
field_data = field.load(model_path, idx, info)
|
field_data = field.load(model_path, idx, info)
|
||||||
except Exception:
|
except Exception:
|
||||||
if self.no_except:
|
if self.no_except:
|
||||||
logger.warn(
|
logger.warning(
|
||||||
'Error occured when loading field %s of model %s'
|
f"Error occured when loading field {field_name} of model {model}",
|
||||||
% (field_name, model)
|
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
@ -150,7 +138,7 @@ class Shapes3dDataset(data.Dataset):
|
||||||
if k is None:
|
if k is None:
|
||||||
data[field_name] = v
|
data[field_name] = v
|
||||||
else:
|
else:
|
||||||
data['%s.%s' % (field_name, k)] = v
|
data[f"{field_name}.{k}"] = v
|
||||||
else:
|
else:
|
||||||
data[field_name] = field_data
|
data[field_name] = field_data
|
||||||
|
|
||||||
|
@ -159,77 +147,75 @@ class Shapes3dDataset(data.Dataset):
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def get_model_dict(self, idx):
|
def get_model_dict(self, idx):
|
||||||
return self.models[idx]
|
return self.models[idx]
|
||||||
|
|
||||||
def test_model_complete(self, category, model):
|
def test_model_complete(self, category, model):
|
||||||
''' Tests if model is complete.
|
"""Tests if model is complete.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (str): modelname
|
model (str): modelname
|
||||||
'''
|
"""
|
||||||
model_path = os.path.join(self.dataset_folder, category, model)
|
model_path = os.path.join(self.dataset_folder, category, model)
|
||||||
files = os.listdir(model_path)
|
files = os.listdir(model_path)
|
||||||
for field_name, field in self.fields.items():
|
for field_name, field in self.fields.items():
|
||||||
if not field.check_complete(files):
|
if not field.check_complete(files):
|
||||||
logger.warn('Field "%s" is incomplete: %s'
|
logger.warning(f'Field "{field_name}" is incomplete: {model_path}')
|
||||||
% (field_name, model_path))
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def collate_remove_none(batch):
|
def collate_remove_none(batch):
|
||||||
''' Collater that puts each data field into a tensor with outer dimension
|
"""Collater that puts each data field into a tensor with outer dimension
|
||||||
batch size.
|
batch size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: batch
|
batch: batch
|
||||||
'''
|
"""
|
||||||
|
|
||||||
batch = list(filter(lambda x: x is not None, batch))
|
batch = list(filter(lambda x: x is not None, batch))
|
||||||
return data.dataloader.default_collate(batch)
|
return data.dataloader.default_collate(batch)
|
||||||
|
|
||||||
|
|
||||||
def collate_stack_together(batch):
|
def collate_stack_together(batch):
|
||||||
''' Collater that puts each data field into a tensor with outer dimension
|
"""Collater that puts each data field into a tensor with outer dimension
|
||||||
batch size.
|
batch size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: batch
|
batch: batch
|
||||||
'''
|
"""
|
||||||
|
|
||||||
batch = list(filter(lambda x: x is not None, batch))
|
batch = list(filter(lambda x: x is not None, batch))
|
||||||
keys = batch[0].keys()
|
keys = batch[0].keys()
|
||||||
concat = {}
|
concat = {}
|
||||||
if len(batch)>1:
|
if len(batch) > 1:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
key_val = [item[key] for item in batch]
|
key_val = [item[key] for item in batch]
|
||||||
concat[key] = np.concatenate(key_val, axis=0)
|
concat[key] = np.concatenate(key_val, axis=0)
|
||||||
if key == 'inputs':
|
if key == "inputs":
|
||||||
n_pts = [item[key].shape[0] for item in batch]
|
n_pts = [item[key].shape[0] for item in batch]
|
||||||
|
|
||||||
concat['batch_ind'] = np.concatenate(
|
concat["batch_ind"] = np.concatenate([i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
|
||||||
[i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
|
|
||||||
|
|
||||||
return data.dataloader.default_collate([concat])
|
return data.dataloader.default_collate([concat])
|
||||||
else:
|
else:
|
||||||
n_pts = batch[0]['inputs'].shape[0]
|
n_pts = batch[0]["inputs"].shape[0]
|
||||||
batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int)
|
batch[0]["batch_ind"] = np.zeros(n_pts, dtype=int)
|
||||||
return data.dataloader.default_collate(batch)
|
return data.dataloader.default_collate(batch)
|
||||||
|
|
||||||
|
|
||||||
def worker_init_fn(worker_id):
|
def worker_init_fn(worker_id):
|
||||||
''' Worker init function to ensure true randomness.
|
"""Worker init function to ensure true randomness."""
|
||||||
'''
|
|
||||||
def set_num_threads(nt):
|
def set_num_threads(nt):
|
||||||
try:
|
try:
|
||||||
import mkl; mkl.set_num_threads(nt)
|
import mkl
|
||||||
|
|
||||||
|
mkl.set_num_threads(nt)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
os.environ['IPC_ENABLE']='1'
|
os.environ["IPC_ENABLE"] = "1"
|
||||||
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
|
for o in ["OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS", "MKL_NUM_THREADS"]:
|
||||||
os.environ[o] = str(nt)
|
os.environ[o] = str(nt)
|
||||||
|
|
||||||
random_data = os.urandom(4)
|
random_data = os.urandom(4)
|
||||||
|
|
|
@ -1,34 +1,32 @@
|
||||||
import os
|
import os
|
||||||
import glob
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import trimesh
|
|
||||||
from src.data.core import Field
|
from src.data.core import Field
|
||||||
from pdb import set_trace as st
|
|
||||||
|
|
||||||
|
|
||||||
class IndexField(Field):
|
class IndexField(Field):
|
||||||
''' Basic index field.'''
|
"""Basic index field."""
|
||||||
|
|
||||||
def load(self, model_path, idx, category):
|
def load(self, model_path, idx, category):
|
||||||
''' Loads the index field.
|
"""Loads the index field.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path (str): path to model
|
model_path (str): path to model
|
||||||
idx (int): ID of data point
|
idx (int): ID of data point
|
||||||
category (int): index of category
|
category (int): index of category
|
||||||
'''
|
"""
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
def check_complete(self, files):
|
def check_complete(self, files):
|
||||||
''' Check if field is complete.
|
"""Check if field is complete.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
files: files
|
files: files
|
||||||
'''
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class FullPSRField(Field):
|
class FullPSRField(Field):
|
||||||
def __init__(self, transform=None, multi_files=None):
|
def __init__(self, transform=None, multi_files=None):
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
@ -36,16 +34,15 @@ class FullPSRField(Field):
|
||||||
self.multi_files = multi_files
|
self.multi_files = multi_files
|
||||||
|
|
||||||
def load(self, model_path, idx, category):
|
def load(self, model_path, idx, category):
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# t0 = time.time()
|
# t0 = time.time()
|
||||||
if self.multi_files is not None:
|
if self.multi_files is not None:
|
||||||
psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx))
|
psr_path = os.path.join(model_path, "psr", f"psr_{idx:02d}.npz")
|
||||||
else:
|
else:
|
||||||
psr_path = os.path.join(model_path, 'psr.npz')
|
psr_path = os.path.join(model_path, "psr.npz")
|
||||||
psr_dict = np.load(psr_path)
|
psr_dict = np.load(psr_path)
|
||||||
# t1 = time.time()
|
# t1 = time.time()
|
||||||
psr = psr_dict['psr']
|
psr = psr_dict["psr"]
|
||||||
psr = psr.astype(np.float32)
|
psr = psr.astype(np.float32)
|
||||||
# t2 = time.time()
|
# t2 = time.time()
|
||||||
# print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0))
|
# print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0))
|
||||||
|
@ -56,8 +53,9 @@ class FullPSRField(Field):
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class PointCloudField(Field):
|
class PointCloudField(Field):
|
||||||
''' Point cloud field.
|
"""Point cloud field.
|
||||||
|
|
||||||
It provides the field used for point cloud data. These are the points
|
It provides the field used for point cloud data. These are the points
|
||||||
randomly sampled on the mesh.
|
randomly sampled on the mesh.
|
||||||
|
@ -66,7 +64,8 @@ class PointCloudField(Field):
|
||||||
file_name (str): file name
|
file_name (str): file name
|
||||||
transform (list): list of transformations applied to data points
|
transform (list): list of transformations applied to data points
|
||||||
multi_files (callable): number of files
|
multi_files (callable): number of files
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
|
def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
|
||||||
self.file_name = file_name
|
self.file_name = file_name
|
||||||
self.data_type = data_type # to make sure the range of input is correct
|
self.data_type = data_type # to make sure the range of input is correct
|
||||||
|
@ -76,43 +75,43 @@ class PointCloudField(Field):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
def load(self, model_path, idx, category):
|
def load(self, model_path, idx, category):
|
||||||
''' Loads the data point.
|
"""Loads the data point.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path (str): path to model
|
model_path (str): path to model
|
||||||
idx (int): ID of data point
|
idx (int): ID of data point
|
||||||
category (int): index of category
|
category (int): index of category
|
||||||
'''
|
"""
|
||||||
if self.multi_files is None:
|
if self.multi_files is None:
|
||||||
file_path = os.path.join(model_path, self.file_name)
|
file_path = os.path.join(model_path, self.file_name)
|
||||||
else:
|
else:
|
||||||
# num = np.random.randint(self.multi_files)
|
# num = np.random.randint(self.multi_files)
|
||||||
# file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num))
|
# file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num))
|
||||||
file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx))
|
file_path = os.path.join(model_path, self.file_name, "pointcloud_%02d.npz" % (idx))
|
||||||
|
|
||||||
pointcloud_dict = np.load(file_path)
|
pointcloud_dict = np.load(file_path)
|
||||||
|
|
||||||
points = pointcloud_dict['points'].astype(np.float32)
|
points = pointcloud_dict["points"].astype(np.float32)
|
||||||
normals = pointcloud_dict['normals'].astype(np.float32)
|
normals = pointcloud_dict["normals"].astype(np.float32)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
None: points,
|
None: points,
|
||||||
'normals': normals,
|
"normals": normals,
|
||||||
}
|
}
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
data = self.transform(data)
|
data = self.transform(data)
|
||||||
|
|
||||||
if self.data_type == 'psr_full':
|
if self.data_type == "psr_full":
|
||||||
# scale the point cloud to the range of (0, 1)
|
# scale the point cloud to the range of (0, 1)
|
||||||
data[None] = data[None] / self.scale + 0.5
|
data[None] = data[None] / self.scale + 0.5
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def check_complete(self, files):
|
def check_complete(self, files):
|
||||||
''' Check if field is complete.
|
"""Check if field is complete.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
files: files
|
files: files
|
||||||
'''
|
"""
|
||||||
complete = (self.file_name in files)
|
complete = self.file_name in files
|
||||||
return complete
|
return complete
|
||||||
|
|
|
@ -2,24 +2,24 @@ import numpy as np
|
||||||
|
|
||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
class PointcloudNoise(object):
|
class PointcloudNoise:
|
||||||
''' Point cloud noise transformation class.
|
"""Point cloud noise transformation class.
|
||||||
|
|
||||||
It adds noise to point cloud data.
|
It adds noise to point cloud data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stddev (int): standard deviation
|
stddev (int): standard deviation
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, stddev):
|
def __init__(self, stddev):
|
||||||
self.stddev = stddev
|
self.stddev = stddev
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
''' Calls the transformation.
|
"""Calls the transformation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dictionary): data dictionary
|
data (dictionary): data dictionary
|
||||||
'''
|
"""
|
||||||
data_out = data.copy()
|
data_out = data.copy()
|
||||||
points = data[None]
|
points = data[None]
|
||||||
noise = self.stddev * np.random.randn(*points.shape)
|
noise = self.stddev * np.random.randn(*points.shape)
|
||||||
|
@ -27,28 +27,29 @@ class PointcloudNoise(object):
|
||||||
data_out[None] = points + noise
|
data_out[None] = points + noise
|
||||||
return data_out
|
return data_out
|
||||||
|
|
||||||
class PointcloudOutliers(object):
|
|
||||||
''' Point cloud outlier transformation class.
|
class PointcloudOutliers:
|
||||||
|
"""Point cloud outlier transformation class.
|
||||||
|
|
||||||
It adds outliers to point cloud data.
|
It adds outliers to point cloud data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ratio (int): outlier percentage to the entire point cloud
|
ratio (int): outlier percentage to the entire point cloud
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, ratio):
|
def __init__(self, ratio):
|
||||||
self.ratio = ratio
|
self.ratio = ratio
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
''' Calls the transformation.
|
"""Calls the transformation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dictionary): data dictionary
|
data (dictionary): data dictionary
|
||||||
'''
|
"""
|
||||||
data_out = data.copy()
|
data_out = data.copy()
|
||||||
points = data[None]
|
points = data[None]
|
||||||
n_points = points.shape[0]
|
n_points = points.shape[0]
|
||||||
n_outlier_points = int(n_points*self.ratio)
|
n_outlier_points = int(n_points * self.ratio)
|
||||||
ind = np.random.randint(0, n_points, n_outlier_points)
|
ind = np.random.randint(0, n_points, n_outlier_points)
|
||||||
|
|
||||||
outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
|
outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
|
||||||
|
@ -57,30 +58,32 @@ class PointcloudOutliers(object):
|
||||||
data_out[None] = points
|
data_out[None] = points
|
||||||
return data_out
|
return data_out
|
||||||
|
|
||||||
class SubsamplePointcloud(object):
|
|
||||||
''' Point cloud subsampling transformation class.
|
class SubsamplePointcloud:
|
||||||
|
"""Point cloud subsampling transformation class.
|
||||||
|
|
||||||
It subsamples the point cloud data.
|
It subsamples the point cloud data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
N (int): number of points to be subsampled
|
N (int): number of points to be subsampled
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, N):
|
def __init__(self, N):
|
||||||
self.N = N
|
self.N = N
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
''' Calls the transformation.
|
"""Calls the transformation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dict): data dictionary
|
data (dict): data dictionary
|
||||||
'''
|
"""
|
||||||
data_out = data.copy()
|
data_out = data.copy()
|
||||||
points = data[None]
|
points = data[None]
|
||||||
|
|
||||||
indices = np.random.randint(points.shape[0], size=self.N)
|
indices = np.random.randint(points.shape[0], size=self.N)
|
||||||
data_out[None] = points[indices, :]
|
data_out[None] = points[indices, :]
|
||||||
if 'normals' in data.keys():
|
if "normals" in data.keys():
|
||||||
normals = data['normals']
|
normals = data["normals"]
|
||||||
data_out['normals'] = normals[indices, :]
|
data_out["normals"] = normals[indices, :]
|
||||||
|
|
||||||
return data_out
|
return data_out
|
|
@ -1,17 +1,20 @@
|
||||||
import os
|
import os
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from torch.utils import data
|
|
||||||
from src.utils import load_rgb, load_mask, get_camera_params
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from pytorch3d.renderer import PerspectiveCameras
|
from pytorch3d.renderer import PerspectiveCameras
|
||||||
from skimage import img_as_float32
|
from skimage import img_as_float32
|
||||||
|
from torch.utils import data
|
||||||
|
|
||||||
|
from src.utils import get_camera_params, load_mask, load_rgb
|
||||||
|
|
||||||
##################################################
|
##################################################
|
||||||
# Below are for the differentiable renderer
|
# Below are for the differentiable renderer
|
||||||
# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py
|
# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py
|
||||||
|
|
||||||
|
|
||||||
def load_rgb(path):
|
def load_rgb(path):
|
||||||
img = imageio.imread(path)
|
img = imageio.imread(path)
|
||||||
img = img_as_float32(img)
|
img = img_as_float32(img)
|
||||||
|
@ -23,6 +26,7 @@ def load_rgb(path):
|
||||||
# img = img.transpose(2, 0, 1)
|
# img = img.transpose(2, 0, 1)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
def load_mask(path):
|
def load_mask(path):
|
||||||
alpha = imageio.imread(path, as_gray=True)
|
alpha = imageio.imread(path, as_gray=True)
|
||||||
alpha = img_as_float32(alpha)
|
alpha = img_as_float32(alpha)
|
||||||
|
@ -32,10 +36,10 @@ def load_mask(path):
|
||||||
|
|
||||||
|
|
||||||
def get_camera_params(uv, pose, intrinsics):
|
def get_camera_params(uv, pose, intrinsics):
|
||||||
if pose.shape[1] == 7: #In case of quaternion vector representation
|
if pose.shape[1] == 7: # In case of quaternion vector representation
|
||||||
cam_loc = pose[:, 4:]
|
cam_loc = pose[:, 4:]
|
||||||
R = quat_to_rot(pose[:,:4])
|
R = quat_to_rot(pose[:, :4])
|
||||||
p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
|
p = torch.eye(4).repeat(pose.shape[0], 1, 1).cuda().float()
|
||||||
p[:, :3, :3] = R
|
p[:, :3, :3] = R
|
||||||
p[:, :3, 3] = cam_loc
|
p[:, :3, 3] = cam_loc
|
||||||
else: # In case of pose matrix representation
|
else: # In case of pose matrix representation
|
||||||
|
@ -60,25 +64,27 @@ def get_camera_params(uv, pose, intrinsics):
|
||||||
|
|
||||||
return ray_dirs, cam_loc
|
return ray_dirs, cam_loc
|
||||||
|
|
||||||
|
|
||||||
def quat_to_rot(q):
|
def quat_to_rot(q):
|
||||||
batch_size, _ = q.shape
|
batch_size, _ = q.shape
|
||||||
q = F.normalize(q, dim=1)
|
q = F.normalize(q, dim=1)
|
||||||
R = torch.ones((batch_size, 3,3)).cuda()
|
R = torch.ones((batch_size, 3, 3)).cuda()
|
||||||
qr=q[:,0]
|
qr = q[:, 0]
|
||||||
qi = q[:, 1]
|
qi = q[:, 1]
|
||||||
qj = q[:, 2]
|
qj = q[:, 2]
|
||||||
qk = q[:, 3]
|
qk = q[:, 3]
|
||||||
R[:, 0, 0]=1-2 * (qj**2 + qk**2)
|
R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2)
|
||||||
R[:, 0, 1] = 2 * (qj *qi -qk*qr)
|
R[:, 0, 1] = 2 * (qj * qi - qk * qr)
|
||||||
R[:, 0, 2] = 2 * (qi * qk + qr * qj)
|
R[:, 0, 2] = 2 * (qi * qk + qr * qj)
|
||||||
R[:, 1, 0] = 2 * (qj * qi + qk * qr)
|
R[:, 1, 0] = 2 * (qj * qi + qk * qr)
|
||||||
R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
|
R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2)
|
||||||
R[:, 1, 2] = 2*(qj*qk - qi*qr)
|
R[:, 1, 2] = 2 * (qj * qk - qi * qr)
|
||||||
R[:, 2, 0] = 2 * (qk * qi-qj * qr)
|
R[:, 2, 0] = 2 * (qk * qi - qj * qr)
|
||||||
R[:, 2, 1] = 2 * (qj*qk + qi*qr)
|
R[:, 2, 1] = 2 * (qj * qk + qi * qr)
|
||||||
R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
|
R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2)
|
||||||
return R
|
return R
|
||||||
|
|
||||||
|
|
||||||
def lift(x, y, z, intrinsics):
|
def lift(x, y, z, intrinsics):
|
||||||
# parse intrinsics
|
# parse intrinsics
|
||||||
# intrinsics = intrinsics.cuda()
|
# intrinsics = intrinsics.cuda()
|
||||||
|
@ -88,7 +94,16 @@ def lift(x, y, z, intrinsics):
|
||||||
cy = intrinsics[:, 1, 2]
|
cy = intrinsics[:, 1, 2]
|
||||||
sk = intrinsics[:, 0, 1]
|
sk = intrinsics[:, 0, 1]
|
||||||
|
|
||||||
x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
|
x_lift = (
|
||||||
|
(
|
||||||
|
x
|
||||||
|
- cx.unsqueeze(-1)
|
||||||
|
+ cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
|
||||||
|
- sk.unsqueeze(-1) * y / fy.unsqueeze(-1)
|
||||||
|
)
|
||||||
|
/ fx.unsqueeze(-1)
|
||||||
|
* z
|
||||||
|
)
|
||||||
y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
|
y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
|
||||||
|
|
||||||
# homogeneous
|
# homogeneous
|
||||||
|
@ -96,21 +111,18 @@ def lift(x, y, z, intrinsics):
|
||||||
|
|
||||||
|
|
||||||
class PixelNeRFDTUDataset(data.Dataset):
|
class PixelNeRFDTUDataset(data.Dataset):
|
||||||
"""
|
"""Processed DTU from pixelNeRF."""
|
||||||
Processed DTU from pixelNeRF
|
|
||||||
"""
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
data_dir='data/DTU',
|
data_dir="data/DTU",
|
||||||
scan_id=65,
|
scan_id=65,
|
||||||
img_size=None,
|
img_size=None,
|
||||||
device=None,
|
device=None,
|
||||||
fixed_scale=0,
|
fixed_scale=0,
|
||||||
):
|
):
|
||||||
data_dir = os.path.join(data_dir, "scan{}".format(scan_id))
|
data_dir = os.path.join(data_dir, f"scan{scan_id}")
|
||||||
rgb_paths = [
|
rgb_paths = [x for x in glob(os.path.join(data_dir, "image", "*")) if (x.endswith((".jpg", ".png")))]
|
||||||
x for x in glob(os.path.join(data_dir, "image", "*"))
|
|
||||||
if (x.endswith(".jpg") or x.endswith(".png"))
|
|
||||||
]
|
|
||||||
rgb_paths = sorted(rgb_paths)
|
rgb_paths = sorted(rgb_paths)
|
||||||
mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))
|
mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))
|
||||||
if len(mask_paths) == 0:
|
if len(mask_paths) == 0:
|
||||||
|
@ -129,21 +141,18 @@ class PixelNeRFDTUDataset(data.Dataset):
|
||||||
all_T = []
|
all_T = []
|
||||||
|
|
||||||
for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)):
|
for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)):
|
||||||
|
|
||||||
i = sel_indices[idx]
|
i = sel_indices[idx]
|
||||||
rgb = load_rgb(rgb_path)
|
rgb = load_rgb(rgb_path)
|
||||||
mask = load_mask(mask_path)
|
mask = load_mask(mask_path)
|
||||||
rgb[~mask] = 0.
|
rgb[~mask] = 0.0
|
||||||
rgb = torch.from_numpy(rgb).float().to(device)
|
rgb = torch.from_numpy(rgb).float().to(device)
|
||||||
mask = torch.from_numpy(mask).float().to(device)
|
mask = torch.from_numpy(mask).float().to(device)
|
||||||
x_scale = y_scale = 1.0
|
|
||||||
xy_delta = 0.0
|
|
||||||
|
|
||||||
P = all_cam["world_mat_" + str(i)]
|
P = all_cam["world_mat_" + str(i)]
|
||||||
P = P[:3]
|
P = P[:3]
|
||||||
|
|
||||||
# scale the original shape to really [-0.9, 0.9]
|
# scale the original shape to really [-0.9, 0.9]
|
||||||
if fixed_scale!=0.:
|
if fixed_scale != 0.0:
|
||||||
scale_mat_new = np.eye(4, 4)
|
scale_mat_new = np.eye(4, 4)
|
||||||
scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
|
scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
|
||||||
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new
|
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new
|
||||||
|
@ -158,38 +167,34 @@ class PixelNeRFDTUDataset(data.Dataset):
|
||||||
|
|
||||||
########!!!!!
|
########!!!!!
|
||||||
RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0)
|
RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0)
|
||||||
tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0)
|
tt = torch.from_numpy(-R @ (t[:3] / t[3])).permute(1, 0)
|
||||||
focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
|
focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
|
||||||
pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
|
pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
|
||||||
im_size = (rgb.shape[1], rgb.shape[0])
|
im_size = (rgb.shape[1], rgb.shape[0])
|
||||||
|
|
||||||
# check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC
|
# check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC
|
||||||
s = min(im_size)
|
s = min(im_size)
|
||||||
focal[:, 0] = focal[:, 0] * 2 / (s-1)
|
focal[:, 0] = focal[:, 0] * 2 / (s - 1)
|
||||||
focal[:, 1] = focal[:, 1] * 2 /(s-1)
|
focal[:, 1] = focal[:, 1] * 2 / (s - 1)
|
||||||
pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1)
|
pc[:, 0] = -(pc[:, 0] - (im_size[0] - 1) / 2) * 2 / (s - 1)
|
||||||
pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1)
|
pc[:, 1] = -(pc[:, 1] - (im_size[1] - 1) / 2) * 2 / (s - 1)
|
||||||
|
|
||||||
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc,
|
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc, device=device, R=RR, T=tt)
|
||||||
device=device, R=RR, T=tt)
|
|
||||||
|
|
||||||
# calculate camera rays
|
# calculate camera rays
|
||||||
uv = uv_creation(im_size)[None].float()
|
uv = uv_creation(im_size)[None].float()
|
||||||
pose = np.eye(4, dtype=np.float32)
|
pose = np.eye(4, dtype=np.float32)
|
||||||
pose[:3, :3] = R.transpose()
|
pose[:3, :3] = R.transpose()
|
||||||
pose[:3,3] = (t[:3] / t[3])[:,0]
|
pose[:3, 3] = (t[:3] / t[3])[:, 0]
|
||||||
pose = torch.from_numpy(pose)[None].float()
|
pose = torch.from_numpy(pose)[None].float()
|
||||||
intrinsics = np.eye(4)
|
intrinsics = np.eye(4)
|
||||||
intrinsics[:3, :3] = K
|
intrinsics[:3, :3] = K
|
||||||
intrinsics[0, 1] = 0. #! remove skew for now
|
intrinsics[0, 1] = 0.0 #! remove skew for now
|
||||||
intrinsics = torch.from_numpy(intrinsics)[None].float()
|
intrinsics = torch.from_numpy(intrinsics)[None].float()
|
||||||
|
|
||||||
|
|
||||||
rays, _ = get_camera_params(uv, pose, intrinsics)
|
rays, _ = get_camera_params(uv, pose, intrinsics)
|
||||||
rays = -rays.to(device)
|
rays = -rays.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
all_poses.append(camera)
|
all_poses.append(camera)
|
||||||
all_imgs.append(rgb)
|
all_imgs.append(rgb)
|
||||||
all_masks.append(mask)
|
all_masks.append(mask)
|
||||||
|
@ -198,7 +203,7 @@ class PixelNeRFDTUDataset(data.Dataset):
|
||||||
# only for neural renderer
|
# only for neural renderer
|
||||||
all_K.append(torch.tensor(K).to(device))
|
all_K.append(torch.tensor(K).to(device))
|
||||||
all_R.append(torch.tensor(R).to(device))
|
all_R.append(torch.tensor(R).to(device))
|
||||||
all_T.append(torch.tensor(t[:3]/t[3]).to(device))
|
all_T.append(torch.tensor(t[:3] / t[3]).to(device))
|
||||||
|
|
||||||
all_imgs = torch.stack(all_imgs)
|
all_imgs = torch.stack(all_imgs)
|
||||||
all_masks = torch.stack(all_masks)
|
all_masks = torch.stack(all_masks)
|
||||||
|
@ -210,15 +215,16 @@ class PixelNeRFDTUDataset(data.Dataset):
|
||||||
all_T = torch.stack(all_T).permute(0, 2, 1).float()
|
all_T = torch.stack(all_T).permute(0, 2, 1).float()
|
||||||
|
|
||||||
uv = uv_creation((all_imgs.size(2), all_imgs.size(1)))
|
uv = uv_creation((all_imgs.size(2), all_imgs.size(1)))
|
||||||
self.data = {'rgbs': all_imgs,
|
self.data = {
|
||||||
'masks': all_masks,
|
"rgbs": all_imgs,
|
||||||
'poses': all_poses,
|
"masks": all_masks,
|
||||||
'rays': all_rays,
|
"poses": all_poses,
|
||||||
'uv': uv,
|
"rays": all_rays,
|
||||||
'light_pose': all_light_pose, # for rendering lights
|
"uv": uv,
|
||||||
'K': all_K,
|
"light_pose": all_light_pose, # for rendering lights
|
||||||
'R': all_R,
|
"K": all_K,
|
||||||
'T': all_T,
|
"R": all_R,
|
||||||
|
"T": all_T,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
37
src/dpsr.py
37
src/dpsr.py
|
@ -1,17 +1,17 @@
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
import torch.fft
|
import torch.fft
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from src.utils import fftfreqs, grid_interp, img, point_rasterize, spec_gaussian_filter
|
||||||
|
|
||||||
|
|
||||||
class DPSR(nn.Module):
|
class DPSR(nn.Module):
|
||||||
def __init__(self, res, sig=10, scale=True, shift=True):
|
def __init__(self, res, sig=10, scale=True, shift=True):
|
||||||
"""
|
""":param res: tuple of output field resolution. eg., (128,128)
|
||||||
:param res: tuple of output field resolution. eg., (128,128)
|
|
||||||
:param sig: degree of gaussian smoothing
|
:param sig: degree of gaussian smoothing
|
||||||
"""
|
"""
|
||||||
super(DPSR, self).__init__()
|
super().__init__()
|
||||||
self.res = res
|
self.res = res
|
||||||
self.sig = sig
|
self.sig = sig
|
||||||
self.dim = len(res)
|
self.dim = len(res)
|
||||||
|
@ -24,16 +24,15 @@ class DPSR(nn.Module):
|
||||||
self.register_buffer("G", G)
|
self.register_buffer("G", G)
|
||||||
|
|
||||||
def forward(self, V, N):
|
def forward(self, V, N):
|
||||||
"""
|
""":param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
|
||||||
:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
|
|
||||||
:param N: (batch, nv, 2 or 3) tensor for point normals
|
:param N: (batch, nv, 2 or 3) tensor for point normals
|
||||||
:return phi: (batch, res, res, ...) tensor of output indicator function field
|
:return phi: (batch, res, res, ...) tensor of output indicator function field
|
||||||
"""
|
"""
|
||||||
assert(V.shape == N.shape) # [b, nv, ndims]
|
assert V.shape == N.shape # [b, nv, ndims]
|
||||||
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
|
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
|
||||||
|
|
||||||
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
|
ras_s = torch.fft.rfftn(ras_p, dim=(2, 3, 4))
|
||||||
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
|
ras_s = ras_s.permute(*tuple([0, *list(range(2, self.dim + 1)), self.dim + 1, 1]))
|
||||||
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
|
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
|
||||||
|
|
||||||
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
|
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
|
||||||
|
@ -43,12 +42,12 @@ class DPSR(nn.Module):
|
||||||
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
|
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
|
||||||
|
|
||||||
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
|
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
|
||||||
Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
|
Phi = DivN / (Lap + 1e-6) # [b, dim0, dim1, dim2/2+1, 2]
|
||||||
Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b]
|
Phi = Phi.permute(*tuple([[*list(range(1, self.dim + 2)), 0]])) # [dim0, dim1, dim2/2+1, 2, b]
|
||||||
Phi[tuple([0] * self.dim)] = 0
|
Phi[tuple([0] * self.dim)] = 0
|
||||||
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
|
Phi = Phi.permute(*tuple([[self.dim + 1, *list(range(self.dim + 1))]])) # [b, dim0, dim1, dim2/2+1, 2]
|
||||||
|
|
||||||
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
|
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1, 2, 3))
|
||||||
|
|
||||||
if self.shift or self.scale:
|
if self.shift or self.scale:
|
||||||
# ensure values at points are zero
|
# ensure values at points are zero
|
||||||
|
@ -57,10 +56,10 @@ class DPSR(nn.Module):
|
||||||
offset = torch.mean(fv, dim=-1) # [b,]
|
offset = torch.mean(fv, dim=-1) # [b,]
|
||||||
phi -= offset.view(*tuple([-1] + [1] * self.dim))
|
phi -= offset.view(*tuple([-1] + [1] * self.dim))
|
||||||
|
|
||||||
phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
|
phi = phi.permute(*tuple([[*list(range(1, self.dim + 1)), 0]]))
|
||||||
fv0 = phi[tuple([0] * self.dim)] # [b,]
|
fv0 = phi[tuple([0] * self.dim)] # [b,]
|
||||||
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
|
phi = phi.permute(*tuple([[self.dim, *list(range(self.dim))]]))
|
||||||
|
|
||||||
if self.scale:
|
if self.scale:
|
||||||
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
|
phi = -phi / torch.abs(fv0.view(*tuple([-1] + [1] * self.dim))) * 0.5
|
||||||
return phi
|
return phi
|
126
src/eval.py
126
src/eval.py
|
@ -1,43 +1,45 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import trimesh
|
|
||||||
from pykdtree.kdtree import KDTree
|
from pykdtree.kdtree import KDTree
|
||||||
|
|
||||||
EMPTY_PCL_DICT = {
|
EMPTY_PCL_DICT = {
|
||||||
'completeness': np.sqrt(3),
|
"completeness": np.sqrt(3),
|
||||||
'accuracy': np.sqrt(3),
|
"accuracy": np.sqrt(3),
|
||||||
'completeness2': 3,
|
"completeness2": 3,
|
||||||
'accuracy2': 3,
|
"accuracy2": 3,
|
||||||
'chamfer': 6,
|
"chamfer": 6,
|
||||||
}
|
}
|
||||||
|
|
||||||
EMPTY_PCL_DICT_NORMALS = {
|
EMPTY_PCL_DICT_NORMALS = {
|
||||||
'normals completeness': -1.,
|
"normals completeness": -1.0,
|
||||||
'normals accuracy': -1.,
|
"normals accuracy": -1.0,
|
||||||
'normals': -1.,
|
"normals": -1.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MeshEvaluator(object):
|
class MeshEvaluator:
|
||||||
''' Mesh evaluation class.
|
"""Mesh evaluation class.
|
||||||
It handles the mesh evaluation process.
|
It handles the mesh evaluation process.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_points (int): number of points to be used for evaluation
|
n_points (int): number of points to be used for evaluation.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, n_points=100000):
|
def __init__(self, n_points=100000):
|
||||||
self.n_points = n_points
|
self.n_points = n_points
|
||||||
|
|
||||||
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)):
|
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1.0 / 1000, 1, 1000)):
|
||||||
''' Evaluates a mesh.
|
"""Evaluates a mesh.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mesh (trimesh): mesh which should be evaluated
|
mesh (trimesh): mesh which should be evaluated
|
||||||
pointcloud_tgt (numpy array): target point cloud
|
pointcloud_tgt (numpy array): target point cloud
|
||||||
normals_tgt (numpy array): target normals
|
normals_tgt (numpy array): target normals
|
||||||
thresholds (numpy arry): for F-Score
|
thresholds (numpy arry): for F-Score.
|
||||||
'''
|
"""
|
||||||
if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
|
if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
|
||||||
pointcloud, idx = mesh.sample(self.n_points, return_index=True)
|
pointcloud, idx = mesh.sample(self.n_points, return_index=True)
|
||||||
|
|
||||||
|
@ -47,25 +49,25 @@ class MeshEvaluator(object):
|
||||||
pointcloud = np.empty((0, 3))
|
pointcloud = np.empty((0, 3))
|
||||||
normals = np.empty((0, 3))
|
normals = np.empty((0, 3))
|
||||||
|
|
||||||
out_dict = self.eval_pointcloud(
|
out_dict = self.eval_pointcloud(pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
|
||||||
pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
|
|
||||||
|
|
||||||
return out_dict
|
return out_dict
|
||||||
|
|
||||||
def eval_pointcloud(self, pointcloud, pointcloud_tgt,
|
def eval_pointcloud(
|
||||||
normals=None, normals_tgt=None,
|
self, pointcloud, pointcloud_tgt, normals=None, normals_tgt=None, thresholds=np.linspace(1.0 / 1000, 1, 1000),
|
||||||
thresholds=np.linspace(1./1000, 1, 1000)):
|
):
|
||||||
''' Evaluates a point cloud.
|
"""Evaluates a point cloud.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pointcloud (numpy array): predicted point cloud
|
pointcloud (numpy array): predicted point cloud
|
||||||
pointcloud_tgt (numpy array): target point cloud
|
pointcloud_tgt (numpy array): target point cloud
|
||||||
normals (numpy array): predicted normals
|
normals (numpy array): predicted normals
|
||||||
normals_tgt (numpy array): target normals
|
normals_tgt (numpy array): target normals
|
||||||
thresholds (numpy array): threshold values for the F-score calculation
|
thresholds (numpy array): threshold values for the F-score calculation.
|
||||||
'''
|
"""
|
||||||
# Return maximum losses if pointcloud is empty
|
# Return maximum losses if pointcloud is empty
|
||||||
if pointcloud.shape[0] == 0:
|
if pointcloud.shape[0] == 0:
|
||||||
logger.warn('Empty pointcloud / mesh detected!')
|
logger.warning("Empty pointcloud / mesh detected!")
|
||||||
out_dict = EMPTY_PCL_DICT.copy()
|
out_dict = EMPTY_PCL_DICT.copy()
|
||||||
if normals is not None and normals_tgt is not None:
|
if normals is not None and normals_tgt is not None:
|
||||||
out_dict.update(EMPTY_PCL_DICT_NORMALS)
|
out_dict.update(EMPTY_PCL_DICT_NORMALS)
|
||||||
|
@ -74,11 +76,13 @@ class MeshEvaluator(object):
|
||||||
pointcloud = np.asarray(pointcloud)
|
pointcloud = np.asarray(pointcloud)
|
||||||
pointcloud_tgt = np.asarray(pointcloud_tgt)
|
pointcloud_tgt = np.asarray(pointcloud_tgt)
|
||||||
|
|
||||||
|
|
||||||
# Completeness: how far are the points of the target point cloud
|
# Completeness: how far are the points of the target point cloud
|
||||||
# from thre predicted point cloud
|
# from thre predicted point cloud
|
||||||
completeness, completeness_normals = distance_p2p(
|
completeness, completeness_normals = distance_p2p(
|
||||||
pointcloud_tgt, normals_tgt, pointcloud, normals
|
pointcloud_tgt,
|
||||||
|
normals_tgt,
|
||||||
|
pointcloud,
|
||||||
|
normals,
|
||||||
)
|
)
|
||||||
recall = get_threshold_percentage(completeness, thresholds)
|
recall = get_threshold_percentage(completeness, thresholds)
|
||||||
completeness2 = completeness**2
|
completeness2 = completeness**2
|
||||||
|
@ -90,7 +94,10 @@ class MeshEvaluator(object):
|
||||||
# Accuracy: how far are th points of the predicted pointcloud
|
# Accuracy: how far are th points of the predicted pointcloud
|
||||||
# from the target pointcloud
|
# from the target pointcloud
|
||||||
accuracy, accuracy_normals = distance_p2p(
|
accuracy, accuracy_normals = distance_p2p(
|
||||||
pointcloud, normals, pointcloud_tgt, normals_tgt
|
pointcloud,
|
||||||
|
normals,
|
||||||
|
pointcloud_tgt,
|
||||||
|
normals_tgt,
|
||||||
)
|
)
|
||||||
precision = get_threshold_percentage(accuracy, thresholds)
|
precision = get_threshold_percentage(accuracy, thresholds)
|
||||||
accuracy2 = accuracy**2
|
accuracy2 = accuracy**2
|
||||||
|
@ -101,68 +108,61 @@ class MeshEvaluator(object):
|
||||||
|
|
||||||
# Chamfer distance
|
# Chamfer distance
|
||||||
chamferL2 = 0.5 * (completeness2 + accuracy2)
|
chamferL2 = 0.5 * (completeness2 + accuracy2)
|
||||||
normals_correctness = (
|
normals_correctness = 0.5 * completeness_normals + 0.5 * accuracy_normals
|
||||||
0.5 * completeness_normals + 0.5 * accuracy_normals
|
|
||||||
)
|
|
||||||
chamferL1 = 0.5 * (completeness + accuracy)
|
chamferL1 = 0.5 * (completeness + accuracy)
|
||||||
|
|
||||||
# F-Score
|
# F-Score
|
||||||
F = [
|
F = [2 * precision[i] * recall[i] / (precision[i] + recall[i]) for i in range(len(precision))]
|
||||||
2 * precision[i] * recall[i] / (precision[i] + recall[i])
|
|
||||||
for i in range(len(precision))
|
|
||||||
]
|
|
||||||
|
|
||||||
out_dict = {
|
out_dict = {
|
||||||
'completeness': completeness,
|
"completeness": completeness,
|
||||||
'accuracy': accuracy,
|
"accuracy": accuracy,
|
||||||
'normals completeness': completeness_normals,
|
"normals completeness": completeness_normals,
|
||||||
'normals accuracy': accuracy_normals,
|
"normals accuracy": accuracy_normals,
|
||||||
'normals': normals_correctness,
|
"normals": normals_correctness,
|
||||||
'completeness2': completeness2,
|
"completeness2": completeness2,
|
||||||
'accuracy2': accuracy2,
|
"accuracy2": accuracy2,
|
||||||
'chamfer-L2': chamferL2,
|
"chamfer-L2": chamferL2,
|
||||||
'chamfer-L1': chamferL1,
|
"chamfer-L1": chamferL1,
|
||||||
'f-score': F[9], # threshold = 1.0%
|
"f-score": F[9], # threshold = 1.0%
|
||||||
'f-score-15': F[14], # threshold = 1.5%
|
"f-score-15": F[14], # threshold = 1.5%
|
||||||
'f-score-20': F[19], # threshold = 2.0%
|
"f-score-20": F[19], # threshold = 2.0%
|
||||||
}
|
}
|
||||||
|
|
||||||
return out_dict
|
return out_dict
|
||||||
|
|
||||||
|
|
||||||
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
|
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
|
||||||
''' Computes minimal distances of each point in points_src to points_tgt.
|
"""Computes minimal distances of each point in points_src to points_tgt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
points_src (numpy array): source points
|
points_src (numpy array): source points
|
||||||
normals_src (numpy array): source normals
|
normals_src (numpy array): source normals
|
||||||
points_tgt (numpy array): target points
|
points_tgt (numpy array): target points
|
||||||
normals_tgt (numpy array): target normals
|
normals_tgt (numpy array): target normals.
|
||||||
'''
|
"""
|
||||||
kdtree = KDTree(points_tgt)
|
kdtree = KDTree(points_tgt)
|
||||||
dist, idx = kdtree.query(points_src)
|
dist, idx = kdtree.query(points_src)
|
||||||
|
|
||||||
if normals_src is not None and normals_tgt is not None:
|
if normals_src is not None and normals_tgt is not None:
|
||||||
normals_src = \
|
normals_src = normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
|
||||||
normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
|
normals_tgt = normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
|
||||||
normals_tgt = \
|
|
||||||
normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
|
|
||||||
|
|
||||||
normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
|
normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
|
||||||
# Handle normals that point into wrong direction gracefully
|
# Handle normals that point into wrong direction gracefully
|
||||||
# (mostly due to mehtod not caring about this in generation)
|
# (mostly due to mehtod not caring about this in generation)
|
||||||
normals_dot_product = np.abs(normals_dot_product)
|
normals_dot_product = np.abs(normals_dot_product)
|
||||||
else:
|
else:
|
||||||
normals_dot_product = np.array(
|
normals_dot_product = np.array([np.nan] * points_src.shape[0], dtype=np.float32)
|
||||||
[np.nan] * points_src.shape[0], dtype=np.float32)
|
|
||||||
return dist, normals_dot_product
|
return dist, normals_dot_product
|
||||||
|
|
||||||
|
|
||||||
def get_threshold_percentage(dist, thresholds):
|
def get_threshold_percentage(dist, thresholds):
|
||||||
''' Evaluates a point cloud.
|
"""Evaluates a point cloud.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dist (numpy array): calculated distance
|
dist (numpy array): calculated distance
|
||||||
thresholds (numpy array): threshold values for the F-score calculation
|
thresholds (numpy array): threshold values for the F-score calculation.
|
||||||
'''
|
"""
|
||||||
in_threshold = [
|
in_threshold = [(dist <= t).mean() for t in thresholds]
|
||||||
(dist <= t).mean() for t in thresholds
|
|
||||||
]
|
|
||||||
return in_threshold
|
return in_threshold
|
|
@ -1,11 +1,12 @@
|
||||||
import torch
|
|
||||||
import time
|
import time
|
||||||
import trimesh
|
|
||||||
import numpy as np
|
import torch
|
||||||
|
|
||||||
from src.utils import mc_from_psr
|
from src.utils import mc_from_psr
|
||||||
|
|
||||||
class Generator3D(object):
|
|
||||||
''' Generator class for Occupancy Networks.
|
class Generator3D:
|
||||||
|
"""Generator class for Occupancy Networks.
|
||||||
|
|
||||||
It provides functions to generate the final mesh as well refining options.
|
It provides functions to generate the final mesh as well refining options.
|
||||||
|
|
||||||
|
@ -17,11 +18,20 @@ class Generator3D(object):
|
||||||
padding (float): how much padding should be used for MISE
|
padding (float): how much padding should be used for MISE
|
||||||
sample (bool): whether z should be sampled
|
sample (bool): whether z should be sampled
|
||||||
input_type (str): type of input
|
input_type (str): type of input
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, model, points_batch_size=100000,
|
def __init__(
|
||||||
threshold=0.5, device=None, padding=0.1,
|
self,
|
||||||
sample=False, input_type = None, dpsr=None, psr_tanh=True):
|
model,
|
||||||
|
points_batch_size=100000,
|
||||||
|
threshold=0.5,
|
||||||
|
device=None,
|
||||||
|
padding=0.1,
|
||||||
|
sample=False,
|
||||||
|
input_type=None,
|
||||||
|
dpsr=None,
|
||||||
|
psr_tanh=True,
|
||||||
|
):
|
||||||
self.model = model.to(device)
|
self.model = model.to(device)
|
||||||
self.points_batch_size = points_batch_size
|
self.points_batch_size = points_batch_size
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
@ -33,29 +43,28 @@ class Generator3D(object):
|
||||||
self.psr_tanh = psr_tanh
|
self.psr_tanh = psr_tanh
|
||||||
|
|
||||||
def generate_mesh(self, data, return_stats=True):
|
def generate_mesh(self, data, return_stats=True):
|
||||||
''' Generates the output mesh.
|
"""Generates the output mesh.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (tensor): data tensor
|
data (tensor): data tensor
|
||||||
return_stats (bool): whether stats should be returned
|
return_stats (bool): whether stats should be returned
|
||||||
'''
|
"""
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
device = self.device
|
device = self.device
|
||||||
stats_dict = {}
|
stats_dict = {}
|
||||||
|
|
||||||
p = data.get('inputs', torch.empty(1, 0)).to(device)
|
p = data.get("inputs", torch.empty(1, 0)).to(device)
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
points, normals = self.model(p)
|
points, normals = self.model(p)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
psr_grid = self.dpsr(points, normals)
|
psr_grid = self.dpsr(points, normals)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
v, f, _ = mc_from_psr(psr_grid,
|
v, f, _ = mc_from_psr(psr_grid, zero_level=self.threshold)
|
||||||
zero_level=self.threshold)
|
stats_dict["pcl"] = t1 - t0
|
||||||
stats_dict['pcl'] = t1 - t0
|
stats_dict["dpsr"] = t2 - t1
|
||||||
stats_dict['dpsr'] = t2 - t1
|
stats_dict["mc"] = time.time() - t2
|
||||||
stats_dict['mc'] = time.time() - t2
|
stats_dict["total"] = time.time() - t0
|
||||||
stats_dict['total'] = time.time() - t0
|
|
||||||
|
|
||||||
if return_stats:
|
if return_stats:
|
||||||
return v, f, points, normals, stats_dict
|
return v, f, points, normals, stats_dict
|
||||||
|
|
90
src/model.py
90
src/model.py
|
@ -1,18 +1,17 @@
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import time
|
import time
|
||||||
from src.utils import point_rasterize, grid_interp, mc_from_psr, \
|
|
||||||
calc_inters_points
|
import torch
|
||||||
from src.dpsr import DPSR
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from src.network import encoder_dict, decoder_dict
|
|
||||||
|
from src.network import decoder_dict, encoder_dict
|
||||||
from src.network.utils import map2local
|
from src.network.utils import map2local
|
||||||
|
from src.utils import calc_inters_points, grid_interp, mc_from_psr, point_rasterize
|
||||||
|
|
||||||
|
|
||||||
class PSR2Mesh(torch.autograd.Function):
|
class PSR2Mesh(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, psr_grid):
|
def forward(ctx, psr_grid):
|
||||||
"""
|
"""In the forward pass we receive a Tensor containing the input and return
|
||||||
In the forward pass we receive a Tensor containing the input and return
|
|
||||||
a Tensor containing the output. ctx is a context object that can be used
|
a Tensor containing the output. ctx is a context object that can be used
|
||||||
to stash information for backward computation. You can cache arbitrary
|
to stash information for backward computation. You can cache arbitrary
|
||||||
objects for use in the backward pass using the ctx.save_for_backward method.
|
objects for use in the backward pass using the ctx.save_for_backward method.
|
||||||
|
@ -29,8 +28,7 @@ class PSR2Mesh(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
|
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
|
||||||
"""
|
"""In the backward pass we receive a Tensor containing the gradient of the loss
|
||||||
In the backward pass we receive a Tensor containing the gradient of the loss
|
|
||||||
with respect to the output, and we need to compute the gradient of the loss
|
with respect to the output, and we need to compute the gradient of the loss
|
||||||
with respect to the input.
|
with respect to the input.
|
||||||
"""
|
"""
|
||||||
|
@ -43,12 +41,12 @@ class PSR2Mesh(torch.autograd.Function):
|
||||||
|
|
||||||
return grad_grid
|
return grad_grid
|
||||||
|
|
||||||
|
|
||||||
class PSR2SurfacePoints(torch.autograd.Function):
|
class PSR2SurfacePoints(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
|
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
|
||||||
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
|
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
|
||||||
verts = verts * 2. - 1. # within the range of [-1, 1]
|
verts = verts * 2.0 - 1.0 # within the range of [-1, 1]
|
||||||
|
|
||||||
|
|
||||||
p_all, n_all, mask_all = [], [], []
|
p_all, n_all, mask_all = [], [], []
|
||||||
|
|
||||||
|
@ -67,7 +65,6 @@ class PSR2SurfacePoints(torch.autograd.Function):
|
||||||
n_inters_all = torch.cat(n_all, dim=0)
|
n_inters_all = torch.cat(n_all, dim=0)
|
||||||
mask_visible = torch.stack(mask_all, dim=0)
|
mask_visible = torch.stack(mask_all, dim=0)
|
||||||
|
|
||||||
|
|
||||||
res = torch.tensor(psr_grid.detach().shape[2])
|
res = torch.tensor(psr_grid.detach().shape[2])
|
||||||
ctx.save_for_backward(p_inters_all, n_inters_all, res)
|
ctx.save_for_backward(p_inters_all, n_inters_all, res)
|
||||||
|
|
||||||
|
@ -80,30 +77,31 @@ class PSR2SurfacePoints(torch.autograd.Function):
|
||||||
|
|
||||||
# grad from the p_inters via MLP renderer
|
# grad from the p_inters via MLP renderer
|
||||||
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
|
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
|
||||||
grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
|
grad_grid_pts = point_rasterize((pts[None] + 1) / 2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
|
||||||
|
|
||||||
return grad_grid_pts, None, None, None, None, None
|
return grad_grid_pts, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class Encode2Points(nn.Module):
|
class Encode2Points(nn.Module):
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
encoder = cfg['model']['encoder']
|
encoder = cfg["model"]["encoder"]
|
||||||
decoder = cfg['model']['decoder']
|
decoder = cfg["model"]["decoder"]
|
||||||
dim = cfg['data']['dim'] # input dim
|
dim = cfg["data"]["dim"] # input dim
|
||||||
c_dim = cfg['model']['c_dim']
|
c_dim = cfg["model"]["c_dim"]
|
||||||
encoder_kwargs = cfg['model']['encoder_kwargs']
|
encoder_kwargs = cfg["model"]["encoder_kwargs"]
|
||||||
if encoder_kwargs == None:
|
if encoder_kwargs is None:
|
||||||
encoder_kwargs = {}
|
encoder_kwargs = {}
|
||||||
decoder_kwargs = cfg['model']['decoder_kwargs']
|
decoder_kwargs = cfg["model"]["decoder_kwargs"]
|
||||||
padding = cfg['data']['padding']
|
cfg["data"]["padding"]
|
||||||
self.predict_normal = cfg['model']['predict_normal']
|
self.predict_normal = cfg["model"]["predict_normal"]
|
||||||
self.predict_offset = cfg['model']['predict_offset']
|
self.predict_offset = cfg["model"]["predict_offset"]
|
||||||
|
|
||||||
out_dim = 3
|
out_dim = 3
|
||||||
out_dim_offset = 3
|
out_dim_offset = 3
|
||||||
num_offset = cfg['data']['num_offset']
|
num_offset = cfg["data"]["num_offset"]
|
||||||
# each point predict more than one offset to add output points
|
# each point predict more than one offset to add output points
|
||||||
if num_offset > 1:
|
if num_offset > 1:
|
||||||
out_dim_offset = out_dim * num_offset
|
out_dim_offset = out_dim * num_offset
|
||||||
|
@ -111,43 +109,39 @@ class Encode2Points(nn.Module):
|
||||||
|
|
||||||
# local mapping
|
# local mapping
|
||||||
self.map2local = None
|
self.map2local = None
|
||||||
if cfg['model']['local_coord']:
|
if cfg["model"]["local_coord"]:
|
||||||
if 'unet' in encoder_kwargs.keys():
|
if "unet" in encoder_kwargs.keys():
|
||||||
unit_size = 1 / encoder_kwargs['plane_resolution']
|
unit_size = 1 / encoder_kwargs["plane_resolution"]
|
||||||
else:
|
else:
|
||||||
unit_size = 1 / encoder_kwargs['grid_resolution']
|
unit_size = 1 / encoder_kwargs["grid_resolution"]
|
||||||
|
|
||||||
local_mapping = map2local(unit_size)
|
local_mapping = map2local(unit_size)
|
||||||
|
|
||||||
self.encoder = encoder_dict[encoder](
|
self.encoder = encoder_dict[encoder](
|
||||||
dim=dim, c_dim=c_dim, map2local=local_mapping,
|
dim=dim,
|
||||||
**encoder_kwargs
|
c_dim=c_dim,
|
||||||
|
map2local=local_mapping,
|
||||||
|
**encoder_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.predict_normal:
|
if self.predict_normal:
|
||||||
# decoder for normal prediction
|
# decoder for normal prediction
|
||||||
self.decoder_normal = decoder_dict[decoder](
|
self.decoder_normal = decoder_dict[decoder](dim=dim, c_dim=c_dim, out_dim=out_dim, **decoder_kwargs)
|
||||||
dim=dim, c_dim=c_dim, out_dim=out_dim,
|
|
||||||
**decoder_kwargs)
|
|
||||||
if self.predict_offset:
|
if self.predict_offset:
|
||||||
# decoder for offset prediction
|
# decoder for offset prediction
|
||||||
self.decoder_offset = decoder_dict[decoder](
|
self.decoder_offset = decoder_dict[decoder](
|
||||||
dim=dim, c_dim=c_dim, out_dim=out_dim_offset,
|
dim=dim, c_dim=c_dim, out_dim=out_dim_offset, map2local=local_mapping, **decoder_kwargs,
|
||||||
map2local=local_mapping,
|
)
|
||||||
**decoder_kwargs)
|
|
||||||
|
|
||||||
self.s_off = cfg['model']['s_offset']
|
|
||||||
|
|
||||||
|
self.s_off = cfg["model"]["s_offset"]
|
||||||
|
|
||||||
def forward(self, p):
|
def forward(self, p):
|
||||||
''' Performs a forward pass through the network.
|
"""Performs a forward pass through the network.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p (tensor): input unoriented points
|
p (tensor): input unoriented points
|
||||||
'''
|
"""
|
||||||
|
|
||||||
time_dict = {}
|
time_dict = {}
|
||||||
mask = None
|
|
||||||
|
|
||||||
batch_size = p.size(0)
|
batch_size = p.size(0)
|
||||||
points = p.clone()
|
points = p.clone()
|
||||||
|
@ -169,13 +163,11 @@ class Encode2Points(nn.Module):
|
||||||
normals = self.decoder_normal(points, c)
|
normals = self.decoder_normal(points, c)
|
||||||
t2 = time.perf_counter()
|
t2 = time.perf_counter()
|
||||||
|
|
||||||
time_dict['encode'] = t1 - t0
|
time_dict["encode"] = t1 - t0
|
||||||
time_dict['predict'] = t2 - t1
|
time_dict["predict"] = t2 - t1
|
||||||
|
|
||||||
points = torch.clamp(points, 0.0, 0.99)
|
points = torch.clamp(points, 0.0, 0.99)
|
||||||
if self.cfg['model']['normal_normalize']:
|
if self.cfg["model"]["normal_normalize"]:
|
||||||
normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8)
|
normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-8)
|
||||||
|
|
||||||
|
|
||||||
return points, normals
|
return points, normals
|
||||||
|
|
|
@ -1,17 +1,19 @@
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
MeshRasterizer,
|
||||||
|
MeshRenderer,
|
||||||
|
PerspectiveCameras,
|
||||||
|
RasterizationSettings,
|
||||||
|
SoftSilhouetteShader,
|
||||||
|
)
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
|
||||||
from src.network.net_rgb import RenderingNetwork
|
from src.network.net_rgb import RenderingNetwork
|
||||||
from src.utils import approx_psr_grad
|
from src.utils import approx_psr_grad
|
||||||
from pytorch3d.renderer import (
|
|
||||||
RasterizationSettings,
|
|
||||||
PerspectiveCameras,
|
|
||||||
MeshRenderer,
|
|
||||||
MeshRasterizer,
|
|
||||||
SoftSilhouetteShader)
|
|
||||||
from pytorch3d.structures import Meshes
|
|
||||||
|
|
||||||
|
|
||||||
def approx_psr_grad(psr_grid, res, normalize=True):
|
def approx_psr_grad(psr_grid, res, normalize=True):
|
||||||
delta_x = delta_y = delta_z = 1/res
|
delta_x = delta_y = delta_z = 1 / res
|
||||||
psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()
|
psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()
|
||||||
|
|
||||||
grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x
|
grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x
|
||||||
|
@ -35,37 +37,37 @@ class SAP2Image(nn.Module):
|
||||||
self.psr2sur = PSR2SurfacePoints.apply
|
self.psr2sur = PSR2SurfacePoints.apply
|
||||||
self.psr2mesh = PSR2Mesh.apply
|
self.psr2mesh = PSR2Mesh.apply
|
||||||
# initialize DPSR
|
# initialize DPSR
|
||||||
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
self.dpsr = DPSR(
|
||||||
cfg['model']['grid_res'],
|
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
|
||||||
cfg['model']['grid_res']),
|
sig=cfg["model"]["psr_sigma"],
|
||||||
sig=cfg['model']['psr_sigma'])
|
)
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
if cfg['train']['l_weight']['rgb'] != 0.:
|
if cfg["train"]["l_weight"]["rgb"] != 0.0:
|
||||||
self.rendering_network = RenderingNetwork(**cfg['model']['renderer'])
|
self.rendering_network = RenderingNetwork(**cfg["model"]["renderer"])
|
||||||
|
|
||||||
if cfg['train']['l_weight']['mask'] != 0.:
|
if cfg["train"]["l_weight"]["mask"] != 0.0:
|
||||||
# initialize rasterizer
|
# initialize rasterizer
|
||||||
sigma = 1e-4
|
sigma = 1e-4
|
||||||
raster_settings_soft = RasterizationSettings(
|
raster_settings_soft = RasterizationSettings(
|
||||||
image_size=img_size,
|
image_size=img_size,
|
||||||
blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
|
blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
|
||||||
faces_per_pixel=150,
|
faces_per_pixel=150,
|
||||||
perspective_correct=False
|
perspective_correct=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize silhouette renderer
|
# initialize silhouette renderer
|
||||||
self.mesh_rasterizer = MeshRenderer(
|
self.mesh_rasterizer = MeshRenderer(
|
||||||
rasterizer=MeshRasterizer(
|
rasterizer=MeshRasterizer(
|
||||||
raster_settings=raster_settings_soft
|
raster_settings=raster_settings_soft,
|
||||||
),
|
),
|
||||||
shader=SoftSilhouetteShader()
|
shader=SoftSilhouetteShader(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
|
|
||||||
def forward(self, inputs, data):
|
def forward(self, inputs, data):
|
||||||
points, normals = inputs[...,:3], inputs[...,3:]
|
points, normals = inputs[..., :3], inputs[..., 3:]
|
||||||
points = torch.sigmoid(points)
|
points = torch.sigmoid(points)
|
||||||
normals = normals / normals.norm(dim=-1, keepdim=True)
|
normals = normals / normals.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
@ -76,35 +78,36 @@ class SAP2Image(nn.Module):
|
||||||
return self.render_img(psr_grid, data)
|
return self.render_img(psr_grid, data)
|
||||||
|
|
||||||
def render_img(self, psr_grid, data):
|
def render_img(self, psr_grid, data):
|
||||||
|
n_views = len(data["masks"])
|
||||||
|
n_views_per_iter = self.cfg["data"]["n_views_per_iter"]
|
||||||
|
|
||||||
n_views = len(data['masks'])
|
self.cfg["model"]["renderer"]["mode"]
|
||||||
n_views_per_iter = self.cfg['data']['n_views_per_iter']
|
uv = data["uv"]
|
||||||
|
|
||||||
rgb_render_mode = self.cfg['model']['renderer']['mode']
|
|
||||||
uv = data['uv']
|
|
||||||
|
|
||||||
idx = np.random.randint(0, n_views, n_views_per_iter)
|
idx = np.random.randint(0, n_views, n_views_per_iter)
|
||||||
pose = [data['poses'][i] for i in idx]
|
pose = [data["poses"][i] for i in idx]
|
||||||
rgb = data['rgbs'][idx]
|
rgb = data["rgbs"][idx]
|
||||||
mask_gt = data['masks'][idx]
|
mask_gt = data["masks"][idx]
|
||||||
ray = None
|
ray = None
|
||||||
pred_rgb = None
|
pred_rgb = None
|
||||||
pred_mask = None
|
pred_mask = None
|
||||||
|
|
||||||
if self.cfg['train']['l_weight']['rgb'] != 0.:
|
if self.cfg["train"]["l_weight"]["rgb"] != 0.0:
|
||||||
psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res'])
|
psr_grad = approx_psr_grad(psr_grid, self.cfg["model"]["grid_res"])
|
||||||
p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
|
p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
|
||||||
n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
|
n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
|
||||||
fea_interp = None
|
fea_interp = None
|
||||||
if 'rays' in data.keys():
|
if "rays" in data.keys():
|
||||||
ray = data['rays'].squeeze()[idx][visible_mask]
|
ray = data["rays"].squeeze()[idx][visible_mask]
|
||||||
pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp)
|
pred_rgb = self.rendering_network(
|
||||||
|
p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp,
|
||||||
|
)
|
||||||
|
|
||||||
# silhouette loss
|
# silhouette loss
|
||||||
if self.cfg['train']['l_weight']['mask'] != 0.:
|
if self.cfg["train"]["l_weight"]["mask"] != 0.0:
|
||||||
# build mesh
|
# build mesh
|
||||||
v, f, _ = self.psr2mesh(psr_grid)
|
v, f, _ = self.psr2mesh(psr_grid)
|
||||||
v = v * 2. - 1 # within the range of [-1, 1]
|
v = v * 2.0 - 1 # within the range of [-1, 1]
|
||||||
# ! Fast but more GPU usage
|
# ! Fast but more GPU usage
|
||||||
mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])
|
mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])
|
||||||
if True:
|
if True:
|
||||||
|
@ -114,11 +117,7 @@ class SAP2Image(nn.Module):
|
||||||
T = torch.cat([p.T for p in pose], dim=0)
|
T = torch.cat([p.T for p in pose], dim=0)
|
||||||
focal = torch.cat([p.focal_length for p in pose], dim=0)
|
focal = torch.cat([p.focal_length for p in pose], dim=0)
|
||||||
pp = torch.cat([p.principal_point for p in pose], dim=0)
|
pp = torch.cat([p.principal_point for p in pose], dim=0)
|
||||||
pose_cur = PerspectiveCameras(
|
pose_cur = PerspectiveCameras(focal_length=focal, principal_point=pp, R=R, T=T, device=R.device)
|
||||||
focal_length=focal,
|
|
||||||
principal_point=pp,
|
|
||||||
R=R, T=T,
|
|
||||||
device=R.device)
|
|
||||||
pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]
|
pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]
|
||||||
else:
|
else:
|
||||||
pred_mask = []
|
pred_mask = []
|
||||||
|
@ -129,11 +128,11 @@ class SAP2Image(nn.Module):
|
||||||
pred_mask = torch.cat(pred_mask, dim=0)
|
pred_mask = torch.cat(pred_mask, dim=0)
|
||||||
|
|
||||||
output = {
|
output = {
|
||||||
'rgb': pred_rgb,
|
"rgb": pred_rgb,
|
||||||
'rgb_gt': rgb,
|
"rgb_gt": rgb,
|
||||||
'mask': pred_mask,
|
"mask": pred_mask,
|
||||||
'mask_gt': mask_gt,
|
"mask_gt": mask_gt,
|
||||||
'vis_mask': visible_mask,
|
"vis_mask": visible_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
return output
|
return output
|
|
@ -1,8 +1,8 @@
|
||||||
from src.network import encoder, decoder
|
from src.network import decoder, encoder
|
||||||
|
|
||||||
encoder_dict = {
|
encoder_dict = {
|
||||||
'local_pool_pointnet': encoder.LocalPoolPointnet,
|
"local_pool_pointnet": encoder.LocalPoolPointnet,
|
||||||
}
|
}
|
||||||
decoder_dict = {
|
decoder_dict = {
|
||||||
'simple_local': decoder.LocalDecoder,
|
"simple_local": decoder.LocalDecoder,
|
||||||
}
|
}
|
|
@ -1,15 +1,17 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
|
||||||
from ipdb import set_trace as st
|
from src.network.utils import (
|
||||||
from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \
|
ResnetBlockFC,
|
||||||
normalize_coordinate
|
normalize_3d_coordinate,
|
||||||
|
normalize_coordinate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LocalDecoder(nn.Module):
|
class LocalDecoder(nn.Module):
|
||||||
''' Decoder.
|
"""Decoder.
|
||||||
Instead of conditioning on global features, on plane/volume local features.
|
Instead of conditioning on global features, on plane/volume local features.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dim (int): input dimension
|
dim (int): input dimension
|
||||||
c_dim (int): dimension of latent conditioned code c
|
c_dim (int): dimension of latent conditioned code c
|
||||||
|
@ -17,26 +19,31 @@ class LocalDecoder(nn.Module):
|
||||||
n_blocks (int): number of blocks ResNetBlockFC layers
|
n_blocks (int): number of blocks ResNetBlockFC layers
|
||||||
leaky (bool): whether to use leaky ReLUs
|
leaky (bool): whether to use leaky ReLUs
|
||||||
sample_mode (str): sampling feature strategy, bilinear|nearest
|
sample_mode (str): sampling feature strategy, bilinear|nearest
|
||||||
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55].
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, dim=3, c_dim=128, out_dim=3,
|
def __init__(
|
||||||
hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None):
|
self,
|
||||||
|
dim=3,
|
||||||
|
c_dim=128,
|
||||||
|
out_dim=3,
|
||||||
|
hidden_size=256,
|
||||||
|
n_blocks=5,
|
||||||
|
leaky=False,
|
||||||
|
sample_mode="bilinear",
|
||||||
|
padding=0.1,
|
||||||
|
map2local=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.c_dim = c_dim
|
self.c_dim = c_dim
|
||||||
self.n_blocks = n_blocks
|
self.n_blocks = n_blocks
|
||||||
|
|
||||||
if c_dim != 0:
|
if c_dim != 0:
|
||||||
self.fc_c = nn.ModuleList([
|
self.fc_c = nn.ModuleList([nn.Linear(c_dim, hidden_size) for i in range(n_blocks)])
|
||||||
nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
self.fc_p = nn.Linear(dim, hidden_size)
|
self.fc_p = nn.Linear(dim, hidden_size)
|
||||||
|
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([ResnetBlockFC(hidden_size) for i in range(n_blocks)])
|
||||||
ResnetBlockFC(hidden_size) for i in range(n_blocks)
|
|
||||||
])
|
|
||||||
|
|
||||||
self.fc_out = nn.Linear(hidden_size, out_dim)
|
self.fc_out = nn.Linear(hidden_size, out_dim)
|
||||||
|
|
||||||
|
@ -50,14 +57,11 @@ class LocalDecoder(nn.Module):
|
||||||
self.map2local = map2local
|
self.map2local = map2local
|
||||||
self.out_dim = out_dim
|
self.out_dim = out_dim
|
||||||
|
|
||||||
|
def sample_plane_feature(self, p, c, plane="xz"):
|
||||||
def sample_plane_feature(self, p, c, plane='xz'):
|
|
||||||
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
|
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
|
||||||
xy = xy[:, :, None].float()
|
xy = xy[:, :, None].float()
|
||||||
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
|
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
|
||||||
c = F.grid_sample(c, vgrid, padding_mode='border',
|
c = F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode).squeeze(-1)
|
||||||
align_corners=True,
|
|
||||||
mode=self.sample_mode).squeeze(-1)
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def sample_grid_feature(self, p, c):
|
def sample_grid_feature(self, p, c):
|
||||||
|
@ -65,23 +69,25 @@ class LocalDecoder(nn.Module):
|
||||||
p_nor = p_nor[:, :, None, None].float()
|
p_nor = p_nor[:, :, None, None].float()
|
||||||
vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
|
vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
|
||||||
# acutally trilinear interpolation if mode = 'bilinear'
|
# acutally trilinear interpolation if mode = 'bilinear'
|
||||||
c = F.grid_sample(c, vgrid, padding_mode='border',
|
c = (
|
||||||
align_corners=True,
|
F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode)
|
||||||
mode=self.sample_mode).squeeze(-1).squeeze(-1)
|
.squeeze(-1)
|
||||||
|
.squeeze(-1)
|
||||||
|
)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def forward(self, p, c_plane, **kwargs):
|
def forward(self, p, c_plane, **kwargs):
|
||||||
batch_size = p.shape[0]
|
batch_size = p.shape[0]
|
||||||
plane_type = list(c_plane.keys())
|
plane_type = list(c_plane.keys())
|
||||||
c = 0
|
c = 0
|
||||||
if 'grid' in plane_type:
|
if "grid" in plane_type:
|
||||||
c += self.sample_grid_feature(p, c_plane['grid'])
|
c += self.sample_grid_feature(p, c_plane["grid"])
|
||||||
if 'xz' in plane_type:
|
if "xz" in plane_type:
|
||||||
c += self.sample_plane_feature(p, c_plane['xz'], plane='xz')
|
c += self.sample_plane_feature(p, c_plane["xz"], plane="xz")
|
||||||
if 'xy' in plane_type:
|
if "xy" in plane_type:
|
||||||
c += self.sample_plane_feature(p, c_plane['xy'], plane='xy')
|
c += self.sample_plane_feature(p, c_plane["xy"], plane="xy")
|
||||||
if 'yz' in plane_type:
|
if "yz" in plane_type:
|
||||||
c += self.sample_plane_feature(p, c_plane['yz'], plane='yz')
|
c += self.sample_plane_feature(p, c_plane["yz"], plane="yz")
|
||||||
c = c.transpose(1, 2)
|
c = c.transpose(1, 2)
|
||||||
|
|
||||||
p = p.float()
|
p = p.float()
|
||||||
|
@ -99,7 +105,6 @@ class LocalDecoder(nn.Module):
|
||||||
|
|
||||||
out = self.fc_out(self.actvn(net))
|
out = self.fc_out(self.actvn(net))
|
||||||
|
|
||||||
|
|
||||||
if self.out_dim > 3:
|
if self.out_dim > 3:
|
||||||
out = out.reshape(batch_size, -1, 3)
|
out = out.reshape(batch_size, -1, 3)
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,23 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
from torch_scatter import scatter_max, scatter_mean
|
||||||
from src.network.unet3d import UNet3D
|
|
||||||
from src.network.unet import UNet
|
from src.network.unet import UNet
|
||||||
from ipdb import set_trace as st
|
from src.network.unet3d import UNet3D
|
||||||
from torch_scatter import scatter_mean, scatter_max
|
from src.network.utils import (
|
||||||
from src.network.utils import get_embedder, normalize_3d_coordinate,\
|
ResnetBlockFC,
|
||||||
coordinate2index, ResnetBlockFC, normalize_coordinate
|
coordinate2index,
|
||||||
|
get_embedder,
|
||||||
|
normalize_3d_coordinate,
|
||||||
|
normalize_coordinate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LocalPoolPointnet(nn.Module):
|
class LocalPoolPointnet(nn.Module):
|
||||||
''' PointNet-based encoder network with ResNet blocks for each point.
|
"""PointNet-based encoder network with ResNet blocks for each point.
|
||||||
Number of input points are fixed.
|
Number of input points are fixed.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
c_dim (int): dimension of latent code c
|
c_dim (int): dimension of latent code c
|
||||||
dim (int): input points dimension
|
dim (int): input points dimension
|
||||||
|
@ -28,20 +34,32 @@ class LocalPoolPointnet(nn.Module):
|
||||||
n_blocks (int): number of blocks ResNetBlockFC layers
|
n_blocks (int): number of blocks ResNetBlockFC layers
|
||||||
map2local (function): map global coordintes to local ones
|
map2local (function): map global coordintes to local ones
|
||||||
pos_encoding (int): frequency for the positional encoding
|
pos_encoding (int): frequency for the positional encoding
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
|
def __init__(
|
||||||
unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
|
self,
|
||||||
plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
|
c_dim=128,
|
||||||
map2local=None, pos_encoding=0):
|
dim=3,
|
||||||
|
hidden_dim=128,
|
||||||
|
scatter_type="max",
|
||||||
|
unet=False,
|
||||||
|
unet_kwargs=None,
|
||||||
|
unet3d=False,
|
||||||
|
unet3d_kwargs=None,
|
||||||
|
plane_resolution=None,
|
||||||
|
grid_resolution=None,
|
||||||
|
plane_type="xz",
|
||||||
|
padding=0.1,
|
||||||
|
n_blocks=5,
|
||||||
|
map2local=None,
|
||||||
|
pos_encoding=0,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.c_dim = c_dim
|
self.c_dim = c_dim
|
||||||
|
|
||||||
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)])
|
||||||
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
|
|
||||||
])
|
|
||||||
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
||||||
|
|
||||||
self.actvn = nn.ReLU()
|
self.actvn = nn.ReLU()
|
||||||
|
@ -60,24 +78,24 @@ class LocalPoolPointnet(nn.Module):
|
||||||
self.plane_type = plane_type
|
self.plane_type = plane_type
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
|
|
||||||
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
|
||||||
self.pe = None
|
self.pe = None
|
||||||
if pos_encoding > 0:
|
if pos_encoding > 0:
|
||||||
embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim)
|
embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim)
|
||||||
self.pe = embed_fn
|
self.pe = embed_fn
|
||||||
self.fc_pos = nn.Linear(input_ch, 2*hidden_dim)
|
self.fc_pos = nn.Linear(input_ch, 2 * hidden_dim)
|
||||||
|
|
||||||
self.map2local = map2local
|
self.map2local = map2local
|
||||||
|
|
||||||
|
if scatter_type == "max":
|
||||||
if scatter_type == 'max':
|
|
||||||
self.scatter = scatter_max
|
self.scatter = scatter_max
|
||||||
elif scatter_type == 'mean':
|
elif scatter_type == "mean":
|
||||||
self.scatter = scatter_mean
|
self.scatter = scatter_mean
|
||||||
else:
|
else:
|
||||||
raise ValueError('incorrect scatter type')
|
msg = "incorrect scatter type"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
def generate_plane_features(self, p, c, plane='xz'):
|
def generate_plane_features(self, p, c, plane="xz"):
|
||||||
# acquire indices of features in plane
|
# acquire indices of features in plane
|
||||||
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
|
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
|
||||||
index = coordinate2index(xy, self.reso_plane)
|
index = coordinate2index(xy, self.reso_plane)
|
||||||
|
@ -86,7 +104,9 @@ class LocalPoolPointnet(nn.Module):
|
||||||
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
|
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
|
||||||
c = c.permute(0, 2, 1) # B x 512 x T
|
c = c.permute(0, 2, 1) # B x 512 x T
|
||||||
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
||||||
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso)
|
fea_plane = fea_plane.reshape(
|
||||||
|
p.size(0), self.c_dim, self.reso_plane, self.reso_plane,
|
||||||
|
) # sparce matrix (B x 512 x reso x reso)
|
||||||
|
|
||||||
# process the plane features with UNet
|
# process the plane features with UNet
|
||||||
if self.unet is not None:
|
if self.unet is not None:
|
||||||
|
@ -96,12 +116,14 @@ class LocalPoolPointnet(nn.Module):
|
||||||
|
|
||||||
def generate_grid_features(self, p, c):
|
def generate_grid_features(self, p, c):
|
||||||
p_nor = normalize_3d_coordinate(p.clone())
|
p_nor = normalize_3d_coordinate(p.clone())
|
||||||
index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
|
index = coordinate2index(p_nor, self.reso_grid, coord_type="3d")
|
||||||
# scatter grid features from points
|
# scatter grid features from points
|
||||||
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
|
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
|
||||||
c = c.permute(0, 2, 1)
|
c = c.permute(0, 2, 1)
|
||||||
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
|
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
|
||||||
fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparce matrix (B x 512 x reso x reso)
|
fea_grid = fea_grid.reshape(
|
||||||
|
p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid,
|
||||||
|
) # sparce matrix (B x 512 x reso x reso)
|
||||||
|
|
||||||
if self.unet3d is not None:
|
if self.unet3d is not None:
|
||||||
fea_grid = self.unet3d(fea_grid)
|
fea_grid = self.unet3d(fea_grid)
|
||||||
|
@ -115,7 +137,7 @@ class LocalPoolPointnet(nn.Module):
|
||||||
c_out = 0
|
c_out = 0
|
||||||
for key in keys:
|
for key in keys:
|
||||||
# scatter plane features from points
|
# scatter plane features from points
|
||||||
if key == 'grid':
|
if key == "grid":
|
||||||
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
|
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
|
||||||
else:
|
else:
|
||||||
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
|
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
|
||||||
|
@ -126,29 +148,27 @@ class LocalPoolPointnet(nn.Module):
|
||||||
c_out += fea
|
c_out += fea
|
||||||
return c_out.permute(0, 2, 1)
|
return c_out.permute(0, 2, 1)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, p, normalize=True):
|
def forward(self, p, normalize=True):
|
||||||
batch_size, T, D = p.size()
|
batch_size, T, D = p.size()
|
||||||
|
|
||||||
# acquire the index for each point
|
# acquire the index for each point
|
||||||
coord = {}
|
coord = {}
|
||||||
index = {}
|
index = {}
|
||||||
if 'xz' in self.plane_type:
|
if "xz" in self.plane_type:
|
||||||
coord['xz'] = normalize_coordinate(p.clone(), plane='xz')
|
coord["xz"] = normalize_coordinate(p.clone(), plane="xz")
|
||||||
index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
|
index["xz"] = coordinate2index(coord["xz"], self.reso_plane)
|
||||||
if 'xy' in self.plane_type:
|
if "xy" in self.plane_type:
|
||||||
coord['xy'] = normalize_coordinate(p.clone(), plane='xy')
|
coord["xy"] = normalize_coordinate(p.clone(), plane="xy")
|
||||||
index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
|
index["xy"] = coordinate2index(coord["xy"], self.reso_plane)
|
||||||
if 'yz' in self.plane_type:
|
if "yz" in self.plane_type:
|
||||||
coord['yz'] = normalize_coordinate(p.clone(), plane='yz')
|
coord["yz"] = normalize_coordinate(p.clone(), plane="yz")
|
||||||
index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
|
index["yz"] = coordinate2index(coord["yz"], self.reso_plane)
|
||||||
if 'grid' in self.plane_type:
|
if "grid" in self.plane_type:
|
||||||
if normalize:
|
if normalize:
|
||||||
coord['grid'] = normalize_3d_coordinate(p.clone())
|
coord["grid"] = normalize_3d_coordinate(p.clone())
|
||||||
else:
|
else:
|
||||||
coord['grid'] = p.clone()[...,:3]
|
coord["grid"] = p.clone()[..., :3]
|
||||||
index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
|
index["grid"] = coordinate2index(coord["grid"], self.reso_grid, coord_type="3d")
|
||||||
|
|
||||||
|
|
||||||
if self.pe:
|
if self.pe:
|
||||||
p = self.pe(p)
|
p = self.pe(p)
|
||||||
|
@ -169,13 +189,13 @@ class LocalPoolPointnet(nn.Module):
|
||||||
c = self.fc_c(net)
|
c = self.fc_c(net)
|
||||||
|
|
||||||
fea = {}
|
fea = {}
|
||||||
if 'grid' in self.plane_type:
|
if "grid" in self.plane_type:
|
||||||
fea['grid'] = self.generate_grid_features(p, c)
|
fea["grid"] = self.generate_grid_features(p, c)
|
||||||
if 'xz' in self.plane_type:
|
if "xz" in self.plane_type:
|
||||||
fea['xz'] = self.generate_plane_features(p, c, plane='xz')
|
fea["xz"] = self.generate_plane_features(p, c, plane="xz")
|
||||||
if 'xy' in self.plane_type:
|
if "xy" in self.plane_type:
|
||||||
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
|
fea["xy"] = self.generate_plane_features(p, c, plane="xy")
|
||||||
if 'yz' in self.plane_type:
|
if "yz" in self.plane_type:
|
||||||
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
|
fea["yz"] = self.generate_plane_features(p, c, plane="yz")
|
||||||
|
|
||||||
return fea
|
return fea
|
|
@ -1,39 +1,41 @@
|
||||||
# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py)
|
# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
|
||||||
from src.network.utils import get_embedder
|
from src.network.utils import get_embedder
|
||||||
from pdb import set_trace as st
|
|
||||||
|
|
||||||
class RenderingNetwork(nn.Module):
|
class RenderingNetwork(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
fea_size=0,
|
fea_size=0,
|
||||||
mode='naive',
|
mode="naive",
|
||||||
d_out=3,
|
d_out=3,
|
||||||
dims=[512, 512, 512, 512],
|
dims=[512, 512, 512, 512],
|
||||||
weight_norm=True,
|
weight_norm=True,
|
||||||
pe_freq_view=0 # for positional encoding
|
pe_freq_view=0, # for positional encoding
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
if mode == 'naive':
|
if mode == "naive":
|
||||||
d_in = 3
|
d_in = 3
|
||||||
elif mode == 'no_feature':
|
elif mode == "no_feature":
|
||||||
d_in = 3 + 3 + 3
|
d_in = 3 + 3 + 3
|
||||||
fea_size = 0
|
fea_size = 0
|
||||||
elif mode == 'full':
|
elif mode == "full":
|
||||||
d_in = 3 + 3 + 3
|
d_in = 3 + 3 + 3
|
||||||
else:
|
else:
|
||||||
d_in = 3 + 3
|
d_in = 3 + 3
|
||||||
dims = [d_in + fea_size] + dims + [d_out]
|
dims = [d_in + fea_size, *dims, d_out]
|
||||||
|
|
||||||
self.embedview_fn = None
|
self.embedview_fn = None
|
||||||
if pe_freq_view > 0:
|
if pe_freq_view > 0:
|
||||||
embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3)
|
embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3)
|
||||||
self.embedview_fn = embedview_fn
|
self.embedview_fn = embedview_fn
|
||||||
dims[0] += (input_ch - 3)
|
dims[0] += input_ch - 3
|
||||||
|
|
||||||
self.num_layers = len(dims)
|
self.num_layers = len(dims)
|
||||||
|
|
||||||
|
@ -54,13 +56,13 @@ class RenderingNetwork(nn.Module):
|
||||||
view_dirs = self.embedview_fn(view_dirs)
|
view_dirs = self.embedview_fn(view_dirs)
|
||||||
# points = self.embedview_fn(points)
|
# points = self.embedview_fn(points)
|
||||||
|
|
||||||
if (self.mode == 'full') & (feature_vectors is not None):
|
if (self.mode == "full") & (feature_vectors is not None):
|
||||||
rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
|
rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
|
||||||
elif (self.mode == 'no_feature') | ((self.mode == 'full') & (feature_vectors is None)):
|
elif (self.mode == "no_feature") | ((self.mode == "full") & (feature_vectors is None)):
|
||||||
rendering_input = torch.cat([points, view_dirs, normals], dim=-1)
|
rendering_input = torch.cat([points, view_dirs, normals], dim=-1)
|
||||||
elif self.mode == 'no_view_dir':
|
elif self.mode == "no_view_dir":
|
||||||
rendering_input = torch.cat([points, normals], dim=-1)
|
rendering_input = torch.cat([points, normals], dim=-1)
|
||||||
elif self.mode == 'no_normal':
|
elif self.mode == "no_normal":
|
||||||
rendering_input = torch.cat([points, view_dirs], dim=-1)
|
rendering_input = torch.cat([points, view_dirs], dim=-1)
|
||||||
else:
|
else:
|
||||||
rendering_input = points
|
rendering_input = points
|
||||||
|
@ -83,25 +85,24 @@ class NeRFRenderingNetwork(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_vector_size=0,
|
feature_vector_size=0,
|
||||||
mode='naive',
|
mode="naive",
|
||||||
d_in=3,
|
d_in=3,
|
||||||
d_out=3,
|
d_out=3,
|
||||||
dims=[512, 512, 512, 256],
|
dims=[512, 512, 512, 256],
|
||||||
weight_norm=True,
|
weight_norm=True,
|
||||||
multires=0, # positional encoding of points
|
multires=0, # positional encoding of points
|
||||||
multires_view=0 # positional encoding of view
|
multires_view=0, # positional encoding of view
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
dims = [d_in + feature_vector_size] + dims
|
dims = [d_in + feature_vector_size, *dims]
|
||||||
|
|
||||||
|
|
||||||
self.embed_fn = None
|
self.embed_fn = None
|
||||||
if multires > 0:
|
if multires > 0:
|
||||||
embed_fn, input_ch = get_embedder(multires, d_in=d_in)
|
embed_fn, input_ch = get_embedder(multires, d_in=d_in)
|
||||||
self.embed_fn = embed_fn
|
self.embed_fn = embed_fn
|
||||||
dims[0] += (input_ch - 3)
|
dims[0] += input_ch - 3
|
||||||
|
|
||||||
self.num_layers = len(dims)
|
self.num_layers = len(dims)
|
||||||
|
|
||||||
|
@ -113,13 +114,12 @@ class NeRFRenderingNetwork(nn.Module):
|
||||||
self.embedview_fn = embedview_fn
|
self.embedview_fn = embedview_fn
|
||||||
# dims[0] += (input_ch - 3)
|
# dims[0] += (input_ch - 3)
|
||||||
|
|
||||||
if mode == 'full':
|
if mode == "full":
|
||||||
self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)])
|
self.view_net = nn.ModuleList([nn.Linear(dims[-1] + view_ch, 128)])
|
||||||
self.rgb_net = nn.Linear(128, 3)
|
self.rgb_net = nn.Linear(128, 3)
|
||||||
else:
|
else:
|
||||||
self.rgb_net = nn.Linear(dims[-1], 3)
|
self.rgb_net = nn.Linear(dims[-1], 3)
|
||||||
|
|
||||||
|
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.tanh = nn.Tanh()
|
self.tanh = nn.Tanh()
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ class NeRFRenderingNetwork(nn.Module):
|
||||||
x = net(x)
|
x = net(x)
|
||||||
x = self.relu(x)
|
x = self.relu(x)
|
||||||
|
|
||||||
if self.mode=='full':
|
if self.mode == "full":
|
||||||
x = torch.cat([x, view_dirs], -1)
|
x = torch.cat([x, view_dirs], -1)
|
||||||
for net in self.view_net:
|
for net in self.view_net:
|
||||||
x = net(x)
|
x = net(x)
|
||||||
|
@ -144,6 +144,7 @@ class NeRFRenderingNetwork(nn.Module):
|
||||||
x = self.tanh(x)
|
x = self.tanh(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ImplicitNetwork(nn.Module):
|
class ImplicitNetwork(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -155,11 +156,11 @@ class ImplicitNetwork(nn.Module):
|
||||||
bias=1.0,
|
bias=1.0,
|
||||||
skip_in=(),
|
skip_in=(),
|
||||||
weight_norm=True,
|
weight_norm=True,
|
||||||
multires=0
|
multires=0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
dims = [d_in] + dims + [d_out + feature_vector_size]
|
dims = [d_in, *dims, d_out + feature_vector_size]
|
||||||
|
|
||||||
self.embed_fn = None
|
self.embed_fn = None
|
||||||
if multires > 0:
|
if multires > 0:
|
||||||
|
@ -189,7 +190,7 @@ class ImplicitNetwork(nn.Module):
|
||||||
elif multires > 0 and l in self.skip_in:
|
elif multires > 0 and l in self.skip_in:
|
||||||
torch.nn.init.constant_(lin.bias, 0.0)
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
||||||
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
|
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
|
||||||
else:
|
else:
|
||||||
torch.nn.init.constant_(lin.bias, 0.0)
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
||||||
|
@ -222,13 +223,9 @@ class ImplicitNetwork(nn.Module):
|
||||||
|
|
||||||
def gradient(self, x):
|
def gradient(self, x):
|
||||||
x.requires_grad_(True)
|
x.requires_grad_(True)
|
||||||
y = self.forward(x)[:,:1]
|
y = self.forward(x)[:, :1]
|
||||||
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
||||||
gradients = torch.autograd.grad(
|
gradients = torch.autograd.grad(
|
||||||
outputs=y,
|
outputs=y, inputs=x, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True,
|
||||||
inputs=x,
|
)[0]
|
||||||
grad_outputs=d_output,
|
|
||||||
create_graph=True,
|
|
||||||
retain_graph=True,
|
|
||||||
only_inputs=True)[0]
|
|
||||||
return gradients.unsqueeze(1)
|
return gradients.unsqueeze(1)
|
|
@ -1,55 +1,38 @@
|
||||||
'''
|
"""Codes are from:
|
||||||
Codes are from:
|
https://github.com/jaxony/unet-pytorch/blob/master/model.py.
|
||||||
https://github.com/jaxony/unet-pytorch/blob/master/model.py
|
"""
|
||||||
'''
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import init
|
from torch.nn import init
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def conv3x3(in_channels, out_channels, stride=1,
|
|
||||||
padding=1, bias=True, groups=1):
|
|
||||||
return nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
bias=bias,
|
|
||||||
groups=groups)
|
|
||||||
|
|
||||||
def upconv2x2(in_channels, out_channels, mode='transpose'):
|
def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1):
|
||||||
if mode == 'transpose':
|
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups)
|
||||||
return nn.ConvTranspose2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
def upconv2x2(in_channels, out_channels, mode="transpose"):
|
||||||
kernel_size=2,
|
if mode == "transpose":
|
||||||
stride=2)
|
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||||
else:
|
else:
|
||||||
# out_channels is always going to be the same
|
# out_channels is always going to be the same
|
||||||
# as in_channels
|
# as in_channels
|
||||||
return nn.Sequential(
|
return nn.Sequential(nn.Upsample(mode="bilinear", scale_factor=2), conv1x1(in_channels, out_channels))
|
||||||
nn.Upsample(mode='bilinear', scale_factor=2),
|
|
||||||
conv1x1(in_channels, out_channels))
|
|
||||||
|
|
||||||
def conv1x1(in_channels, out_channels, groups=1):
|
def conv1x1(in_channels, out_channels, groups=1):
|
||||||
return nn.Conv2d(
|
return nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, stride=1)
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
groups=groups,
|
|
||||||
stride=1)
|
|
||||||
|
|
||||||
|
|
||||||
class DownConv(nn.Module):
|
class DownConv(nn.Module):
|
||||||
"""
|
"""A helper Module that performs 2 convolutions and 1 MaxPool.
|
||||||
A helper Module that performs 2 convolutions and 1 MaxPool.
|
|
||||||
A ReLU activation follows each convolution.
|
A ReLU activation follows each convolution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, pooling=True):
|
def __init__(self, in_channels, out_channels, pooling=True):
|
||||||
super(DownConv, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
@ -71,39 +54,35 @@ class DownConv(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class UpConv(nn.Module):
|
class UpConv(nn.Module):
|
||||||
"""
|
"""A helper Module that performs 2 convolutions and 1 UpConvolution.
|
||||||
A helper Module that performs 2 convolutions and 1 UpConvolution.
|
|
||||||
A ReLU activation follows each convolution.
|
A ReLU activation follows each convolution.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_channels, out_channels,
|
|
||||||
merge_mode='concat', up_mode='transpose'):
|
def __init__(self, in_channels, out_channels, merge_mode="concat", up_mode="transpose"):
|
||||||
super(UpConv, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.merge_mode = merge_mode
|
self.merge_mode = merge_mode
|
||||||
self.up_mode = up_mode
|
self.up_mode = up_mode
|
||||||
|
|
||||||
self.upconv = upconv2x2(self.in_channels, self.out_channels,
|
self.upconv = upconv2x2(self.in_channels, self.out_channels, mode=self.up_mode)
|
||||||
mode=self.up_mode)
|
|
||||||
|
|
||||||
if self.merge_mode == 'concat':
|
if self.merge_mode == "concat":
|
||||||
self.conv1 = conv3x3(
|
self.conv1 = conv3x3(2 * self.out_channels, self.out_channels)
|
||||||
2*self.out_channels, self.out_channels)
|
|
||||||
else:
|
else:
|
||||||
# num of input channels to conv2 is same
|
# num of input channels to conv2 is same
|
||||||
self.conv1 = conv3x3(self.out_channels, self.out_channels)
|
self.conv1 = conv3x3(self.out_channels, self.out_channels)
|
||||||
self.conv2 = conv3x3(self.out_channels, self.out_channels)
|
self.conv2 = conv3x3(self.out_channels, self.out_channels)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, from_down, from_up):
|
def forward(self, from_down, from_up):
|
||||||
""" Forward pass
|
"""Forward pass
|
||||||
Arguments:
|
Arguments:
|
||||||
from_down: tensor from the encoder pathway
|
from_down: tensor from the encoder pathway
|
||||||
from_up: upconv'd tensor from the decoder pathway
|
from_up: upconv'd tensor from the decoder pathway.
|
||||||
"""
|
"""
|
||||||
from_up = self.upconv(from_up)
|
from_up = self.upconv(from_up)
|
||||||
if self.merge_mode == 'concat':
|
if self.merge_mode == "concat":
|
||||||
x = torch.cat((from_up, from_down), 1)
|
x = torch.cat((from_up, from_down), 1)
|
||||||
else:
|
else:
|
||||||
x = from_up + from_down
|
x = from_up + from_down
|
||||||
|
@ -113,7 +92,7 @@ class UpConv(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
""" `UNet` class is based on https://arxiv.org/abs/1505.04597
|
"""`UNet` class is based on https://arxiv.org/abs/1505.04597.
|
||||||
|
|
||||||
The U-Net is a convolutional encoder-decoder neural network.
|
The U-Net is a convolutional encoder-decoder neural network.
|
||||||
Contextual spatial information (from the decoding,
|
Contextual spatial information (from the decoding,
|
||||||
|
@ -135,11 +114,10 @@ class UNet(nn.Module):
|
||||||
the tranpose convolution (specified by upmode='transpose')
|
the tranpose convolution (specified by upmode='transpose')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_classes, in_channels=3, depth=5,
|
def __init__(
|
||||||
start_filts=64, up_mode='transpose',
|
self, num_classes, in_channels=3, depth=5, start_filts=64, up_mode="transpose", merge_mode="concat", **kwargs,
|
||||||
merge_mode='concat', **kwargs):
|
):
|
||||||
"""
|
"""Arguments:
|
||||||
Arguments:
|
|
||||||
in_channels: int, number of channels in the input tensor.
|
in_channels: int, number of channels in the input tensor.
|
||||||
Default is 3 for RGB images.
|
Default is 3 for RGB images.
|
||||||
depth: int, number of MaxPools in the U-Net.
|
depth: int, number of MaxPools in the U-Net.
|
||||||
|
@ -149,30 +127,24 @@ class UNet(nn.Module):
|
||||||
for transpose convolution or 'upsample' for nearest neighbour
|
for transpose convolution or 'upsample' for nearest neighbour
|
||||||
upsampling.
|
upsampling.
|
||||||
"""
|
"""
|
||||||
super(UNet, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
if up_mode in ('transpose', 'upsample'):
|
if up_mode in ("transpose", "upsample"):
|
||||||
self.up_mode = up_mode
|
self.up_mode = up_mode
|
||||||
else:
|
else:
|
||||||
raise ValueError("\"{}\" is not a valid mode for "
|
msg = f'"{up_mode}" is not a valid mode for upsampling. Only "transpose" and "upsample" are allowed.'
|
||||||
"upsampling. Only \"transpose\" and "
|
raise ValueError(msg)
|
||||||
"\"upsample\" are allowed.".format(up_mode))
|
|
||||||
|
|
||||||
if merge_mode in ('concat', 'add'):
|
if merge_mode in ("concat", "add"):
|
||||||
self.merge_mode = merge_mode
|
self.merge_mode = merge_mode
|
||||||
else:
|
else:
|
||||||
raise ValueError("\"{}\" is not a valid mode for"
|
msg = f'"{up_mode}" is not a valid mode formerging up and down paths. Only "concat" and "add" are allowed.'
|
||||||
"merging up and down paths. "
|
raise ValueError(msg)
|
||||||
"Only \"concat\" and "
|
|
||||||
"\"add\" are allowed.".format(up_mode))
|
|
||||||
|
|
||||||
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
|
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
|
||||||
if self.up_mode == 'upsample' and self.merge_mode == 'add':
|
if self.up_mode == "upsample" and self.merge_mode == "add":
|
||||||
raise ValueError("up_mode \"upsample\" is incompatible "
|
msg = 'up_mode "upsample" is incompatible with merge_mode "add" at the moment because it doesn\'t make sense to use nearest neighbour to reduce depth channels (by half).'
|
||||||
"with merge_mode \"add\" at the moment "
|
raise ValueError(msg)
|
||||||
"because it doesn't make sense to use "
|
|
||||||
"nearest neighbour to reduce "
|
|
||||||
"depth channels (by half).")
|
|
||||||
|
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
@ -185,19 +157,18 @@ class UNet(nn.Module):
|
||||||
# create the encoder pathway and add to a list
|
# create the encoder pathway and add to a list
|
||||||
for i in range(depth):
|
for i in range(depth):
|
||||||
ins = self.in_channels if i == 0 else outs
|
ins = self.in_channels if i == 0 else outs
|
||||||
outs = self.start_filts*(2**i)
|
outs = self.start_filts * (2**i)
|
||||||
pooling = True if i < depth-1 else False
|
pooling = True if i < depth - 1 else False
|
||||||
|
|
||||||
down_conv = DownConv(ins, outs, pooling=pooling)
|
down_conv = DownConv(ins, outs, pooling=pooling)
|
||||||
self.down_convs.append(down_conv)
|
self.down_convs.append(down_conv)
|
||||||
|
|
||||||
# create the decoder pathway and add to a list
|
# create the decoder pathway and add to a list
|
||||||
# - careful! decoding only requires depth-1 blocks
|
# - careful! decoding only requires depth-1 blocks
|
||||||
for i in range(depth-1):
|
for i in range(depth - 1):
|
||||||
ins = outs
|
ins = outs
|
||||||
outs = ins // 2
|
outs = ins // 2
|
||||||
up_conv = UpConv(ins, outs, up_mode=up_mode,
|
up_conv = UpConv(ins, outs, up_mode=up_mode, merge_mode=merge_mode)
|
||||||
merge_mode=merge_mode)
|
|
||||||
self.up_convs.append(up_conv)
|
self.up_convs.append(up_conv)
|
||||||
|
|
||||||
# add the list of modules to current module
|
# add the list of modules to current module
|
||||||
|
@ -214,12 +185,10 @@ class UNet(nn.Module):
|
||||||
init.xavier_normal_(m.weight)
|
init.xavier_normal_(m.weight)
|
||||||
init.constant_(m.bias, 0)
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
def reset_params(self):
|
def reset_params(self):
|
||||||
for i, m in enumerate(self.modules()):
|
for _i, m in enumerate(self.modules()):
|
||||||
self.weight_init(m)
|
self.weight_init(m)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
encoder_outs = []
|
encoder_outs = []
|
||||||
# encoder pathway, save outputs for merging
|
# encoder pathway, save outputs for merging
|
||||||
|
@ -227,7 +196,7 @@ class UNet(nn.Module):
|
||||||
x, before_pool = module(x)
|
x, before_pool = module(x)
|
||||||
encoder_outs.append(before_pool)
|
encoder_outs.append(before_pool)
|
||||||
for i, module in enumerate(self.up_convs):
|
for i, module in enumerate(self.up_convs):
|
||||||
before_pool = encoder_outs[-(i+2)]
|
before_pool = encoder_outs[-(i + 2)]
|
||||||
x = module(before_pool, x)
|
x = module(before_pool, x)
|
||||||
|
|
||||||
# No softmax is used. This means you need to use
|
# No softmax is used. This means you need to use
|
||||||
|
@ -236,21 +205,22 @@ class UNet(nn.Module):
|
||||||
x = self.conv_final(x)
|
x = self.conv_final(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
testing
|
testing
|
||||||
"""
|
"""
|
||||||
model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
|
model = UNet(1, depth=5, merge_mode="concat", in_channels=1, start_filts=32)
|
||||||
print(model)
|
print(model)
|
||||||
print(sum(p.numel() for p in model.parameters()))
|
print(sum(p.numel() for p in model.parameters()))
|
||||||
|
|
||||||
reso = 176
|
reso = 176
|
||||||
x = np.zeros((1, 1, reso, reso))
|
x = np.zeros((1, 1, reso, reso))
|
||||||
x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
|
x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
|
||||||
x = torch.FloatTensor(x)
|
x = torch.FloatTensor(x)
|
||||||
|
|
||||||
out = model(x)
|
out = model(x)
|
||||||
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
|
print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso)))
|
||||||
|
|
||||||
# loss = torch.sum(out)
|
# loss = torch.sum(out)
|
||||||
# loss.backward()
|
# loss.backward()
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
'''
|
"""Code from the 3D UNet implementation:
|
||||||
Code from the 3D UNet implementation:
|
https://github.com/wolny/pytorch-3dunet/.
|
||||||
https://github.com/wolny/pytorch-3dunet/
|
"""
|
||||||
'''
|
|
||||||
import importlib
|
import importlib
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from functools import partial
|
|
||||||
from src.network.utils import get_embedder
|
from src.network.utils import get_embedder
|
||||||
|
|
||||||
|
|
||||||
def number_of_features_per_level(init_channel_number, num_levels):
|
def number_of_features_per_level(init_channel_number, num_levels):
|
||||||
return [init_channel_number * 2 ** k for k in range(num_levels)]
|
return [init_channel_number * 2**k for k in range(num_levels)]
|
||||||
|
|
||||||
|
|
||||||
def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
|
def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
|
||||||
|
@ -18,8 +20,7 @@ def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
|
||||||
|
|
||||||
|
|
||||||
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
|
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
|
||||||
"""
|
"""Create a list of modules with together constitute a single conv layer with non-linearity
|
||||||
Create a list of modules with together constitute a single conv layer with non-linearity
|
|
||||||
and optional batchnorm/groupnorm.
|
and optional batchnorm/groupnorm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -37,23 +38,23 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
|
||||||
Return:
|
Return:
|
||||||
list of tuple (name, module)
|
list of tuple (name, module)
|
||||||
"""
|
"""
|
||||||
assert 'c' in order, "Conv layer MUST be present"
|
assert "c" in order, "Conv layer MUST be present"
|
||||||
assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
|
assert order[0] not in "rle", "Non-linearity cannot be the first operation in the layer"
|
||||||
|
|
||||||
modules = []
|
modules = []
|
||||||
for i, char in enumerate(order):
|
for i, char in enumerate(order):
|
||||||
if char == 'r':
|
if char == "r":
|
||||||
modules.append(('ReLU', nn.ReLU(inplace=True)))
|
modules.append(("ReLU", nn.ReLU(inplace=True)))
|
||||||
elif char == 'l':
|
elif char == "l":
|
||||||
modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
|
modules.append(("LeakyReLU", nn.LeakyReLU(negative_slope=0.1, inplace=True)))
|
||||||
elif char == 'e':
|
elif char == "e":
|
||||||
modules.append(('ELU', nn.ELU(inplace=True)))
|
modules.append(("ELU", nn.ELU(inplace=True)))
|
||||||
elif char == 'c':
|
elif char == "c":
|
||||||
# add learnable bias only in the absence of batchnorm/groupnorm
|
# add learnable bias only in the absence of batchnorm/groupnorm
|
||||||
bias = not ('g' in order or 'b' in order)
|
bias = not ("g" in order or "b" in order)
|
||||||
modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
|
modules.append(("conv", conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
|
||||||
elif char == 'g':
|
elif char == "g":
|
||||||
is_before_conv = i < order.index('c')
|
is_before_conv = i < order.index("c")
|
||||||
if is_before_conv:
|
if is_before_conv:
|
||||||
num_channels = in_channels
|
num_channels = in_channels
|
||||||
else:
|
else:
|
||||||
|
@ -63,14 +64,16 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
|
||||||
if num_channels < num_groups:
|
if num_channels < num_groups:
|
||||||
num_groups = 1
|
num_groups = 1
|
||||||
|
|
||||||
assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
|
assert (
|
||||||
modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
|
num_channels % num_groups == 0
|
||||||
elif char == 'b':
|
), f"Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}"
|
||||||
is_before_conv = i < order.index('c')
|
modules.append(("groupnorm", nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
|
||||||
|
elif char == "b":
|
||||||
|
is_before_conv = i < order.index("c")
|
||||||
if is_before_conv:
|
if is_before_conv:
|
||||||
modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
|
modules.append(("batchnorm", nn.BatchNorm3d(in_channels)))
|
||||||
else:
|
else:
|
||||||
modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
|
modules.append(("batchnorm", nn.BatchNorm3d(out_channels)))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
|
raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
|
||||||
|
|
||||||
|
@ -78,9 +81,8 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
|
||||||
|
|
||||||
|
|
||||||
class SingleConv(nn.Sequential):
|
class SingleConv(nn.Sequential):
|
||||||
"""
|
"""Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
|
||||||
Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
|
of operations can be specified via the `order` parameter.
|
||||||
of operations can be specified via the `order` parameter
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): number of input channels
|
in_channels (int): number of input channels
|
||||||
|
@ -94,16 +96,15 @@ class SingleConv(nn.Sequential):
|
||||||
num_groups (int): number of groups for the GroupNorm
|
num_groups (int): number of groups for the GroupNorm
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
|
def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8, padding=1):
|
||||||
super(SingleConv, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
|
for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
|
||||||
self.add_module(name, module)
|
self.add_module(name, module)
|
||||||
|
|
||||||
|
|
||||||
class DoubleConv(nn.Sequential):
|
class DoubleConv(nn.Sequential):
|
||||||
"""
|
"""A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
|
||||||
A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
|
|
||||||
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
||||||
This can be changed however by providing the 'order' argument, e.g. in order
|
This can be changed however by providing the 'order' argument, e.g. in order
|
||||||
to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
|
to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
|
||||||
|
@ -123,8 +124,8 @@ class DoubleConv(nn.Sequential):
|
||||||
num_groups (int): number of groups for the GroupNorm
|
num_groups (int): number of groups for the GroupNorm
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
|
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order="crg", num_groups=8):
|
||||||
super(DoubleConv, self).__init__()
|
super().__init__()
|
||||||
if encoder:
|
if encoder:
|
||||||
# we're in the encoder path
|
# we're in the encoder path
|
||||||
conv1_in_channels = in_channels
|
conv1_in_channels = in_channels
|
||||||
|
@ -138,26 +139,27 @@ class DoubleConv(nn.Sequential):
|
||||||
conv2_in_channels, conv2_out_channels = out_channels, out_channels
|
conv2_in_channels, conv2_out_channels = out_channels, out_channels
|
||||||
|
|
||||||
# conv1
|
# conv1
|
||||||
self.add_module('SingleConv1',
|
self.add_module(
|
||||||
SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
|
"SingleConv1", SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups),
|
||||||
|
)
|
||||||
# conv2
|
# conv2
|
||||||
self.add_module('SingleConv2',
|
self.add_module(
|
||||||
SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
|
"SingleConv2", SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExtResNetBlock(nn.Module):
|
class ExtResNetBlock(nn.Module):
|
||||||
"""
|
"""Basic UNet block consisting of a SingleConv followed by the residual block.
|
||||||
Basic UNet block consisting of a SingleConv followed by the residual block.
|
|
||||||
The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
|
The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
|
||||||
of output channels is compatible with the residual block that follows.
|
of output channels is compatible with the residual block that follows.
|
||||||
This block can be used instead of standard DoubleConv in the Encoder module.
|
This block can be used instead of standard DoubleConv in the Encoder module.
|
||||||
Motivated by: https://arxiv.org/pdf/1706.00120.pdf
|
Motivated by: https://arxiv.org/pdf/1706.00120.pdf.
|
||||||
|
|
||||||
Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
|
Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
|
def __init__(self, in_channels, out_channels, kernel_size=3, order="cge", num_groups=8, **kwargs):
|
||||||
super(ExtResNetBlock, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
# first convolution
|
# first convolution
|
||||||
self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
|
self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
|
||||||
|
@ -165,15 +167,16 @@ class ExtResNetBlock(nn.Module):
|
||||||
self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
|
self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
|
||||||
# remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
|
# remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
|
||||||
n_order = order
|
n_order = order
|
||||||
for c in 'rel':
|
for c in "rel":
|
||||||
n_order = n_order.replace(c, '')
|
n_order = n_order.replace(c, "")
|
||||||
self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
|
self.conv3 = SingleConv(
|
||||||
num_groups=num_groups)
|
out_channels, out_channels, kernel_size=kernel_size, order=n_order, num_groups=num_groups,
|
||||||
|
)
|
||||||
|
|
||||||
# create non-linearity separately
|
# create non-linearity separately
|
||||||
if 'l' in order:
|
if "l" in order:
|
||||||
self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
elif 'e' in order:
|
elif "e" in order:
|
||||||
self.non_linearity = nn.ELU(inplace=True)
|
self.non_linearity = nn.ELU(inplace=True)
|
||||||
else:
|
else:
|
||||||
self.non_linearity = nn.ReLU(inplace=True)
|
self.non_linearity = nn.ReLU(inplace=True)
|
||||||
|
@ -194,12 +197,12 @@ class ExtResNetBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
"""
|
"""A single module from the encoder path consisting of the optional max
|
||||||
A single module from the encoder path consisting of the optional max
|
|
||||||
pooling layer (one may specify the MaxPool kernel_size to be different
|
pooling layer (one may specify the MaxPool kernel_size to be different
|
||||||
than the standard (2,2,2), e.g. if the volumetric data is anisotropic
|
than the standard (2,2,2), e.g. if the volumetric data is anisotropic
|
||||||
(make sure to use complementary scale_factor in the decoder path) followed by
|
(make sure to use complementary scale_factor in the decoder path) followed by
|
||||||
a DoubleConv module.
|
a DoubleConv module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): number of input channels
|
in_channels (int): number of input channels
|
||||||
out_channels (int): number of output channels
|
out_channels (int): number of output channels
|
||||||
|
@ -210,27 +213,39 @@ class Encoder(nn.Module):
|
||||||
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
||||||
conv_layer_order (string): determines the order of layers
|
conv_layer_order (string): determines the order of layers
|
||||||
in `DoubleConv` module. See `DoubleConv` for more info.
|
in `DoubleConv` module. See `DoubleConv` for more info.
|
||||||
num_groups (int): number of groups for the GroupNorm
|
num_groups (int): number of groups for the GroupNorm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
|
def __init__(
|
||||||
pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg',
|
self,
|
||||||
num_groups=8):
|
in_channels,
|
||||||
super(Encoder, self).__init__()
|
out_channels,
|
||||||
assert pool_type in ['max', 'avg']
|
conv_kernel_size=3,
|
||||||
|
apply_pooling=True,
|
||||||
|
pool_kernel_size=(2, 2, 2),
|
||||||
|
pool_type="max",
|
||||||
|
basic_module=DoubleConv,
|
||||||
|
conv_layer_order="crg",
|
||||||
|
num_groups=8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert pool_type in ["max", "avg"]
|
||||||
if apply_pooling:
|
if apply_pooling:
|
||||||
if pool_type == 'max':
|
if pool_type == "max":
|
||||||
self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
|
self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
|
||||||
else:
|
else:
|
||||||
self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
|
self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
|
||||||
else:
|
else:
|
||||||
self.pooling = None
|
self.pooling = None
|
||||||
|
|
||||||
self.basic_module = basic_module(in_channels, out_channels,
|
self.basic_module = basic_module(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
encoder=True,
|
encoder=True,
|
||||||
kernel_size=conv_kernel_size,
|
kernel_size=conv_kernel_size,
|
||||||
order=conv_layer_order,
|
order=conv_layer_order,
|
||||||
num_groups=num_groups)
|
num_groups=num_groups,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.pooling is not None:
|
if self.pooling is not None:
|
||||||
|
@ -240,9 +255,9 @@ class Encoder(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
"""
|
"""A single module for decoder path consisting of the upsampling layer
|
||||||
A single module for decoder path consisting of the upsampling layer
|
|
||||||
(either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
|
(either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): number of input channels
|
in_channels (int): number of input channels
|
||||||
out_channels (int): number of output channels
|
out_channels (int): number of output channels
|
||||||
|
@ -253,32 +268,56 @@ class Decoder(nn.Module):
|
||||||
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
||||||
conv_layer_order (string): determines the order of layers
|
conv_layer_order (string): determines the order of layers
|
||||||
in `DoubleConv` module. See `DoubleConv` for more info.
|
in `DoubleConv` module. See `DoubleConv` for more info.
|
||||||
num_groups (int): number of groups for the GroupNorm
|
num_groups (int): number of groups for the GroupNorm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
|
def __init__(
|
||||||
conv_layer_order='crg', num_groups=8, mode='nearest'):
|
self,
|
||||||
super(Decoder, self).__init__()
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
scale_factor=(2, 2, 2),
|
||||||
|
basic_module=DoubleConv,
|
||||||
|
conv_layer_order="crg",
|
||||||
|
num_groups=8,
|
||||||
|
mode="nearest",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
if basic_module == DoubleConv:
|
if basic_module == DoubleConv:
|
||||||
# if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
|
# if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
|
||||||
self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels,
|
self.upsampling = Upsampling(
|
||||||
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
|
transposed_conv=False,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
scale_factor=scale_factor,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
# concat joining
|
# concat joining
|
||||||
self.joining = partial(self._joining, concat=True)
|
self.joining = partial(self._joining, concat=True)
|
||||||
else:
|
else:
|
||||||
# if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
|
# if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
|
||||||
self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels,
|
self.upsampling = Upsampling(
|
||||||
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
|
transposed_conv=True,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
scale_factor=scale_factor,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
# sum joining
|
# sum joining
|
||||||
self.joining = partial(self._joining, concat=False)
|
self.joining = partial(self._joining, concat=False)
|
||||||
# adapt the number of in_channels for the ExtResNetBlock
|
# adapt the number of in_channels for the ExtResNetBlock
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
|
|
||||||
self.basic_module = basic_module(in_channels, out_channels,
|
self.basic_module = basic_module(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
encoder=False,
|
encoder=False,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
order=conv_layer_order,
|
order=conv_layer_order,
|
||||||
num_groups=num_groups)
|
num_groups=num_groups,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, encoder_features, x):
|
def forward(self, encoder_features, x):
|
||||||
x = self.upsampling(encoder_features=encoder_features, x=x)
|
x = self.upsampling(encoder_features=encoder_features, x=x)
|
||||||
|
@ -295,8 +334,7 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Upsampling(nn.Module):
|
class Upsampling(nn.Module):
|
||||||
"""
|
"""Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
|
||||||
Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
|
transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
|
||||||
|
@ -310,15 +348,23 @@ class Upsampling(nn.Module):
|
||||||
'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
|
'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3,
|
def __init__(
|
||||||
scale_factor=(2, 2, 2), mode='nearest'):
|
self,
|
||||||
super(Upsampling, self).__init__()
|
transposed_conv,
|
||||||
|
in_channels=None,
|
||||||
|
out_channels=None,
|
||||||
|
kernel_size=3,
|
||||||
|
scale_factor=(2, 2, 2),
|
||||||
|
mode="nearest",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
if transposed_conv:
|
if transposed_conv:
|
||||||
# make sure that the output size reverses the MaxPool3d from the corresponding encoder
|
# make sure that the output size reverses the MaxPool3d from the corresponding encoder
|
||||||
# (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0])
|
# (D_out = (D_in - 1) x stride[0] - 2 x padding[0] + kernel_size[0] + output_padding[0])
|
||||||
self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
|
self.upsample = nn.ConvTranspose3d(
|
||||||
padding=1)
|
in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, padding=1,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.upsample = partial(self._interpolate, mode=mode)
|
self.upsample = partial(self._interpolate, mode=mode)
|
||||||
|
|
||||||
|
@ -332,13 +378,13 @@ class Upsampling(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class FinalConv(nn.Sequential):
|
class FinalConv(nn.Sequential):
|
||||||
"""
|
"""A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
|
||||||
A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
|
|
||||||
which reduces the number of channels to 'out_channels'.
|
which reduces the number of channels to 'out_channels'.
|
||||||
with the number of output channels 'out_channels // 2' and 'out_channels' respectively.
|
with the number of output channels 'out_channels // 2' and 'out_channels' respectively.
|
||||||
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
||||||
This can be change however by providing the 'order' argument, e.g. in order
|
This can be change however by providing the 'order' argument, e.g. in order
|
||||||
to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
|
to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): number of input channels
|
in_channels (int): number of input channels
|
||||||
out_channels (int): number of output channels
|
out_channels (int): number of output channels
|
||||||
|
@ -346,22 +392,22 @@ class FinalConv(nn.Sequential):
|
||||||
order (string): determines the order of layers, e.g.
|
order (string): determines the order of layers, e.g.
|
||||||
'cr' -> conv + ReLU
|
'cr' -> conv + ReLU
|
||||||
'crg' -> conv + ReLU + groupnorm
|
'crg' -> conv + ReLU + groupnorm
|
||||||
num_groups (int): number of groups for the GroupNorm
|
num_groups (int): number of groups for the GroupNorm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
|
def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8):
|
||||||
super(FinalConv, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
# conv1
|
# conv1
|
||||||
self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
|
self.add_module("SingleConv", SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
|
||||||
|
|
||||||
# in the last layer a 1×1 convolution reduces the number of output channels to out_channels
|
# in the last layer a 1x1 convolution reduces the number of output channels to out_channels
|
||||||
final_conv = nn.Conv3d(in_channels, out_channels, 1)
|
final_conv = nn.Conv3d(in_channels, out_channels, 1)
|
||||||
self.add_module('final_conv', final_conv)
|
self.add_module("final_conv", final_conv)
|
||||||
|
|
||||||
|
|
||||||
class Abstract3DUNet(nn.Module):
|
class Abstract3DUNet(nn.Module):
|
||||||
"""
|
"""Base class for standard and residual UNet.
|
||||||
Base class for standard and residual UNet.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): number of input channels
|
in_channels (int): number of input channels
|
||||||
|
@ -391,9 +437,22 @@ class Abstract3DUNet(nn.Module):
|
||||||
and the `final_activation` (even if present) won't be applied; default: False
|
and the `final_activation` (even if present) won't be applied; default: False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
|
def __init__(
|
||||||
num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs):
|
self,
|
||||||
super(Abstract3DUNet, self).__init__()
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
final_sigmoid,
|
||||||
|
basic_module,
|
||||||
|
f_maps=64,
|
||||||
|
layer_order="gcr",
|
||||||
|
num_groups=8,
|
||||||
|
num_levels=4,
|
||||||
|
is_segmentation=False,
|
||||||
|
testing=False,
|
||||||
|
pe_freq=0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.testing = testing
|
self.testing = testing
|
||||||
|
|
||||||
|
@ -411,13 +470,24 @@ class Abstract3DUNet(nn.Module):
|
||||||
encoders = []
|
encoders = []
|
||||||
for i, out_feature_num in enumerate(f_maps):
|
for i, out_feature_num in enumerate(f_maps):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module,
|
encoder = Encoder(
|
||||||
conv_layer_order=layer_order, num_groups=num_groups)
|
in_channels,
|
||||||
|
out_feature_num,
|
||||||
|
apply_pooling=False,
|
||||||
|
basic_module=basic_module,
|
||||||
|
conv_layer_order=layer_order,
|
||||||
|
num_groups=num_groups,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
|
# TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
|
||||||
# currently pools with a constant kernel: (2, 2, 2)
|
# currently pools with a constant kernel: (2, 2, 2)
|
||||||
encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module,
|
encoder = Encoder(
|
||||||
conv_layer_order=layer_order, num_groups=num_groups)
|
f_maps[i - 1],
|
||||||
|
out_feature_num,
|
||||||
|
basic_module=basic_module,
|
||||||
|
conv_layer_order=layer_order,
|
||||||
|
num_groups=num_groups,
|
||||||
|
)
|
||||||
encoders.append(encoder)
|
encoders.append(encoder)
|
||||||
|
|
||||||
self.encoders = nn.ModuleList(encoders)
|
self.encoders = nn.ModuleList(encoders)
|
||||||
|
@ -434,13 +504,18 @@ class Abstract3DUNet(nn.Module):
|
||||||
out_feature_num = reversed_f_maps[i + 1]
|
out_feature_num = reversed_f_maps[i + 1]
|
||||||
# TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
|
# TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
|
||||||
# currently strides with a constant stride: (2, 2, 2)
|
# currently strides with a constant stride: (2, 2, 2)
|
||||||
decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module,
|
decoder = Decoder(
|
||||||
conv_layer_order=layer_order, num_groups=num_groups)
|
in_feature_num,
|
||||||
|
out_feature_num,
|
||||||
|
basic_module=basic_module,
|
||||||
|
conv_layer_order=layer_order,
|
||||||
|
num_groups=num_groups,
|
||||||
|
)
|
||||||
decoders.append(decoder)
|
decoders.append(decoder)
|
||||||
|
|
||||||
self.decoders = nn.ModuleList(decoders)
|
self.decoders = nn.ModuleList(decoders)
|
||||||
|
|
||||||
# in the last layer a 1×1 convolution reduces the number of output
|
# in the last layer a 1x1 convolution reduces the number of output
|
||||||
# channels to the number of labels
|
# channels to the number of labels
|
||||||
self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
|
self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
|
||||||
|
|
||||||
|
@ -455,7 +530,6 @@ class Abstract3DUNet(nn.Module):
|
||||||
self.final_activation = None
|
self.final_activation = None
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
if self.embed_fn is not None:
|
if self.embed_fn is not None:
|
||||||
x = self.embed_fn(x.permute(0, 2, 3, 4, 1))
|
x = self.embed_fn(x.permute(0, 2, 3, 4, 1))
|
||||||
x = x.permute(0, 4, 1, 2, 3)
|
x = x.permute(0, 4, 1, 2, 3)
|
||||||
|
@ -488,49 +562,81 @@ class Abstract3DUNet(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class UNet3D(Abstract3DUNet):
|
class UNet3D(Abstract3DUNet):
|
||||||
"""
|
"""3DUnet model from
|
||||||
3DUnet model from
|
|
||||||
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
|
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
|
||||||
<https://arxiv.org/pdf/1606.06650.pdf>`.
|
<https://arxiv.org/pdf/1606.06650.pdf>`.
|
||||||
|
|
||||||
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
|
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
|
def __init__(
|
||||||
num_groups=8, num_levels=4, is_segmentation=True, **kwargs):
|
self,
|
||||||
super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid,
|
in_channels,
|
||||||
basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order,
|
out_channels,
|
||||||
num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation,
|
final_sigmoid=True,
|
||||||
**kwargs)
|
f_maps=64,
|
||||||
|
layer_order="gcr",
|
||||||
|
num_groups=8,
|
||||||
|
num_levels=4,
|
||||||
|
is_segmentation=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
final_sigmoid=final_sigmoid,
|
||||||
|
basic_module=DoubleConv,
|
||||||
|
f_maps=f_maps,
|
||||||
|
layer_order=layer_order,
|
||||||
|
num_groups=num_groups,
|
||||||
|
num_levels=num_levels,
|
||||||
|
is_segmentation=is_segmentation,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResidualUNet3D(Abstract3DUNet):
|
class ResidualUNet3D(Abstract3DUNet):
|
||||||
"""
|
"""Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
|
||||||
Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
|
|
||||||
Uses ExtResNetBlock as a basic building block, summation joining instead
|
Uses ExtResNetBlock as a basic building block, summation joining instead
|
||||||
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
|
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
|
||||||
Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
|
Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
|
def __init__(
|
||||||
num_groups=8, num_levels=5, is_segmentation=True, **kwargs):
|
self,
|
||||||
super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels,
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
final_sigmoid=True,
|
||||||
|
f_maps=64,
|
||||||
|
layer_order="gcr",
|
||||||
|
num_groups=8,
|
||||||
|
num_levels=5,
|
||||||
|
is_segmentation=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
final_sigmoid=final_sigmoid,
|
final_sigmoid=final_sigmoid,
|
||||||
basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order,
|
basic_module=ExtResNetBlock,
|
||||||
num_groups=num_groups, num_levels=num_levels,
|
f_maps=f_maps,
|
||||||
|
layer_order=layer_order,
|
||||||
|
num_groups=num_groups,
|
||||||
|
num_levels=num_levels,
|
||||||
is_segmentation=is_segmentation,
|
is_segmentation=is_segmentation,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_model(config):
|
def get_model(config):
|
||||||
def _model_class(class_name):
|
def _model_class(class_name):
|
||||||
m = importlib.import_module('pytorch3dunet.unet3d.model')
|
m = importlib.import_module("pytorch3dunet.unet3d.model")
|
||||||
clazz = getattr(m, class_name)
|
clazz = getattr(m, class_name)
|
||||||
return clazz
|
return clazz
|
||||||
|
|
||||||
assert 'model' in config, 'Could not find model configuration'
|
assert "model" in config, "Could not find model configuration"
|
||||||
model_config = config['model']
|
model_config = config["model"]
|
||||||
model_class = _model_class(model_config['name'])
|
model_class = _model_class(model_config["name"])
|
||||||
return model_class(**model_config)
|
return model_class(**model_config)
|
||||||
|
|
||||||
|
|
||||||
|
@ -542,18 +648,18 @@ if __name__ == "__main__":
|
||||||
out_channels = 1
|
out_channels = 1
|
||||||
f_maps = 32
|
f_maps = 32
|
||||||
num_levels = 2
|
num_levels = 2
|
||||||
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr')
|
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order="cr")
|
||||||
print(model)
|
print(model)
|
||||||
print('number of parameters: ', sum(p.numel() for p in model.parameters()))
|
print("number of parameters: ", sum(p.numel() for p in model.parameters()))
|
||||||
|
|
||||||
reso = 18
|
reso = 18
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
x = np.zeros((1, 1, reso, reso, reso))
|
x = np.zeros((1, 1, reso, reso, reso))
|
||||||
x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan
|
x[:, :, int(reso / 2 - 1), int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
|
||||||
x = torch.FloatTensor(x)
|
x = torch.FloatTensor(x)
|
||||||
|
|
||||||
out = model(x)
|
out = model(x)
|
||||||
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso)))
|
print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso * reso)))
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
|
"""Positional encoding embedding. Code was taken from https://github.com/bmild/nerf."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class Embedder:
|
class Embedder:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
@ -10,24 +11,23 @@ class Embedder:
|
||||||
|
|
||||||
def create_embedding_fn(self):
|
def create_embedding_fn(self):
|
||||||
embed_fns = []
|
embed_fns = []
|
||||||
d = self.kwargs['input_dims']
|
d = self.kwargs["input_dims"]
|
||||||
out_dim = 0
|
out_dim = 0
|
||||||
if self.kwargs['include_input']:
|
if self.kwargs["include_input"]:
|
||||||
embed_fns.append(lambda x: x)
|
embed_fns.append(lambda x: x)
|
||||||
out_dim += d
|
out_dim += d
|
||||||
|
|
||||||
max_freq = self.kwargs['max_freq_log2']
|
max_freq = self.kwargs["max_freq_log2"]
|
||||||
N_freqs = self.kwargs['num_freqs']
|
N_freqs = self.kwargs["num_freqs"]
|
||||||
|
|
||||||
if self.kwargs['log_sampling']:
|
if self.kwargs["log_sampling"]:
|
||||||
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
|
freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs)
|
||||||
else:
|
else:
|
||||||
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
|
freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs)
|
||||||
|
|
||||||
for freq in freq_bands:
|
for freq in freq_bands:
|
||||||
for p_fn in self.kwargs['periodic_fns']:
|
for p_fn in self.kwargs["periodic_fns"]:
|
||||||
embed_fns.append(lambda x, p_fn=p_fn,
|
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
|
||||||
freq=freq: p_fn(x * freq))
|
|
||||||
out_dim += d
|
out_dim += d
|
||||||
|
|
||||||
self.embed_fns = embed_fns
|
self.embed_fns = embed_fns
|
||||||
|
@ -36,31 +36,36 @@ class Embedder:
|
||||||
def embed(self, inputs):
|
def embed(self, inputs):
|
||||||
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
||||||
|
|
||||||
|
|
||||||
def get_embedder(multires, d_in=3):
|
def get_embedder(multires, d_in=3):
|
||||||
embed_kwargs = {
|
embed_kwargs = {
|
||||||
'include_input': True,
|
"include_input": True,
|
||||||
'input_dims': d_in,
|
"input_dims": d_in,
|
||||||
'max_freq_log2': multires-1,
|
"max_freq_log2": multires - 1,
|
||||||
'num_freqs': multires,
|
"num_freqs": multires,
|
||||||
'log_sampling': True,
|
"log_sampling": True,
|
||||||
'periodic_fns': [torch.sin, torch.cos],
|
"periodic_fns": [torch.sin, torch.cos],
|
||||||
}
|
}
|
||||||
|
|
||||||
embedder_obj = Embedder(**embed_kwargs)
|
embedder_obj = Embedder(**embed_kwargs)
|
||||||
def embed(x, eo=embedder_obj): return eo.embed(x)
|
|
||||||
|
def embed(x, eo=embedder_obj):
|
||||||
|
return eo.embed(x)
|
||||||
|
|
||||||
return embed, embedder_obj.out_dim
|
return embed, embedder_obj.out_dim
|
||||||
|
|
||||||
def normalize_coordinate(p, plane='xz'):
|
|
||||||
''' Normalize coordinate to [0, 1] for unit cube experiments
|
def normalize_coordinate(p, plane="xz"):
|
||||||
|
"""Normalize coordinate to [0, 1] for unit cube experiments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p (tensor): point
|
p (tensor): point
|
||||||
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
||||||
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
||||||
'''
|
"""
|
||||||
if plane == 'xz':
|
if plane == "xz":
|
||||||
xy = p[:, :, [0, 2]]
|
xy = p[:, :, [0, 2]]
|
||||||
elif plane =='xy':
|
elif plane == "xy":
|
||||||
xy = p[:, :, [0, 1]]
|
xy = p[:, :, [0, 1]]
|
||||||
else:
|
else:
|
||||||
xy = p[:, :, [1, 2]]
|
xy = p[:, :, [1, 2]]
|
||||||
|
@ -75,40 +80,41 @@ def normalize_coordinate(p, plane='xz'):
|
||||||
|
|
||||||
|
|
||||||
def normalize_3d_coordinate(p):
|
def normalize_3d_coordinate(p):
|
||||||
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
"""Normalize coordinate to [0, 1] for unit cube experiments."""
|
||||||
'''
|
|
||||||
if p.max() >= 1:
|
if p.max() >= 1:
|
||||||
p[p >= 1] = 1 - 10e-6
|
p[p >= 1] = 1 - 10e-6
|
||||||
if p.min() < 0:
|
if p.min() < 0:
|
||||||
p[p < 0] = 0.0
|
p[p < 0] = 0.0
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def coordinate2index(x, reso, coord_type='2d'):
|
|
||||||
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
def coordinate2index(x, reso, coord_type="2d"):
|
||||||
Corresponds to our 3D model
|
"""Normalize coordinate to [0, 1] for unit cube experiments.
|
||||||
|
Corresponds to our 3D model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (tensor): coordinate
|
x (tensor): coordinate
|
||||||
reso (int): defined resolution
|
reso (int): defined resolution
|
||||||
coord_type (str): coordinate type
|
coord_type (str): coordinate type
|
||||||
'''
|
"""
|
||||||
x = (x * reso).long()
|
x = (x * reso).long()
|
||||||
if coord_type == '2d': # plane
|
if coord_type == "2d": # plane
|
||||||
index = x[:, :, 0] + reso * x[:, :, 1]
|
index = x[:, :, 0] + reso * x[:, :, 1]
|
||||||
elif coord_type == '3d': # grid
|
elif coord_type == "3d": # grid
|
||||||
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
|
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
|
||||||
index = index[:, None, :]
|
index = index[:, None, :]
|
||||||
return index
|
return index
|
||||||
|
|
||||||
|
|
||||||
class map2local(object):
|
class map2local:
|
||||||
''' Add new keys to the given input
|
"""Add new keys to the given input.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
s (float): the defined voxel size
|
s (float): the defined voxel size
|
||||||
pos_encoding (str): method for the positional encoding, linear|sin_cos
|
pos_encoding (str): method for the positional encoding, linear|sin_cos
|
||||||
'''
|
"""
|
||||||
def __init__(self, s, pos_encoding='linear'):
|
|
||||||
|
def __init__(self, s, pos_encoding="linear"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.s = s
|
self.s = s
|
||||||
# self.pe = positional_encoding(basis_function=pos_encoding, local=True)
|
# self.pe = positional_encoding(basis_function=pos_encoding, local=True)
|
||||||
|
@ -121,15 +127,16 @@ class map2local(object):
|
||||||
# p = self.pe(p)
|
# p = self.pe(p)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
# Resnet Blocks
|
# Resnet Blocks
|
||||||
class ResnetBlockFC(nn.Module):
|
class ResnetBlockFC(nn.Module):
|
||||||
''' Fully connected ResNet Block class.
|
"""Fully connected ResNet Block class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
size_in (int): input dimension
|
size_in (int): input dimension
|
||||||
size_out (int): output dimension
|
size_out (int): output dimension
|
||||||
size_h (int): hidden dimension
|
size_h (int): hidden dimension
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, size_in, size_out=None, size_h=None, siren=False):
|
def __init__(self, size_in, size_out=None, size_h=None, siren=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -1,54 +1,56 @@
|
||||||
import time, os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import open3d as o3d
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
|
||||||
import trimesh
|
import trimesh
|
||||||
|
from pytorch3d.loss import chamfer_distance
|
||||||
|
from torchvision.io import write_video
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
|
||||||
from src.dpsr import DPSR
|
from src.dpsr import DPSR
|
||||||
from src.model import PSR2Mesh
|
from src.model import PSR2Mesh
|
||||||
from src.utils import grid_interp, verts_on_largest_mesh,\
|
from src.utils import export_pointcloud, mc_from_psr, verts_on_largest_mesh
|
||||||
export_pointcloud, mc_from_psr, GaussianSmoothing
|
from src.visualize import (
|
||||||
from src.visualize import visualize_points_mesh, visualize_psr_grid, \
|
render_rgb,
|
||||||
visualize_mesh_phong, render_rgb
|
visualize_mesh_phong,
|
||||||
from torchvision.utils import save_image
|
visualize_points_mesh,
|
||||||
from torchvision.io import write_video
|
visualize_psr_grid,
|
||||||
from pytorch3d.loss import chamfer_distance
|
)
|
||||||
import open3d as o3d
|
|
||||||
|
|
||||||
class Trainer(object):
|
|
||||||
'''
|
class Trainer:
|
||||||
Args:
|
"""Args:
|
||||||
cfg : config file
|
cfg : config file
|
||||||
optimizer : pytorch optimizer object
|
optimizer : pytorch optimizer object
|
||||||
device : pytorch device
|
device : pytorch device.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg, optimizer, device=None):
|
def __init__(self, cfg, optimizer, device=None):
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.device = device
|
self.device = device
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.psr2mesh = PSR2Mesh.apply
|
self.psr2mesh = PSR2Mesh.apply
|
||||||
self.data_type = cfg['data']['data_type']
|
self.data_type = cfg["data"]["data_type"]
|
||||||
|
|
||||||
# initialize DPSR
|
# initialize DPSR
|
||||||
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
self.dpsr = DPSR(
|
||||||
cfg['model']['grid_res'],
|
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
|
||||||
cfg['model']['grid_res']),
|
sig=cfg["model"]["psr_sigma"],
|
||||||
sig=cfg['model']['psr_sigma'])
|
)
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR
|
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR
|
||||||
self.dpsr = self.dpsr.to(device)
|
self.dpsr = self.dpsr.to(device)
|
||||||
|
|
||||||
def train_step(self, data, inputs, model, it):
|
def train_step(self, data, inputs, model, it):
|
||||||
''' Performs a training step.
|
"""Performs a training step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dict) : data dictionary
|
data (dict) : data dictionary
|
||||||
inputs (torch.tensor) : input point clouds
|
inputs (torch.tensor) : input point clouds
|
||||||
model (nn.Module or None): a neural network or None
|
model (nn.Module or None): a neural network or None
|
||||||
it (int) : the number of iterations
|
it (int) : the number of iterations
|
||||||
'''
|
"""
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
loss, loss_each = self.compute_loss(inputs, data, model, it)
|
loss, loss_each = self.compute_loss(inputs, data, model, it)
|
||||||
|
|
||||||
|
@ -58,16 +60,15 @@ class Trainer(object):
|
||||||
return loss.item(), loss_each
|
return loss.item(), loss_each
|
||||||
|
|
||||||
def compute_loss(self, inputs, data, model, it=0):
|
def compute_loss(self, inputs, data, model, it=0):
|
||||||
''' Compute the loss.
|
"""Compute the loss.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dict) : data dictionary
|
data (dict) : data dictionary
|
||||||
inputs (torch.tensor) : input point clouds
|
inputs (torch.tensor) : input point clouds
|
||||||
model (nn.Module or None): a neural network or None
|
model (nn.Module or None): a neural network or None
|
||||||
it (int) : the number of iterations
|
it (int) : the number of iterations.
|
||||||
'''
|
"""
|
||||||
|
res = self.cfg["model"]["grid_res"]
|
||||||
device = self.device
|
|
||||||
res = self.cfg['model']['grid_res']
|
|
||||||
|
|
||||||
# source oriented point clouds to PSR grid
|
# source oriented point clouds to PSR grid
|
||||||
psr_grid, points, normals = self.pcl2psr(inputs)
|
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||||
|
@ -77,36 +78,33 @@ class Trainer(object):
|
||||||
|
|
||||||
# the output is in the range of [0, 1), we make it to the real range [0, 1].
|
# the output is in the range of [0, 1), we make it to the real range [0, 1].
|
||||||
# This is a hack for our DPSR solver
|
# This is a hack for our DPSR solver
|
||||||
v = v * res / (res-1)
|
v = v * res / (res - 1)
|
||||||
|
|
||||||
points = points * 2. - 1.
|
points = points * 2.0 - 1.0
|
||||||
v = v * 2. - 1. # within the range of (-1, 1)
|
v = v * 2.0 - 1.0 # within the range of (-1, 1)
|
||||||
|
|
||||||
loss = 0
|
loss = 0
|
||||||
loss_each = {}
|
loss_each = {}
|
||||||
# compute loss
|
# compute loss
|
||||||
if self.data_type == 'point':
|
if self.data_type == "point":
|
||||||
if self.cfg['train']['w_chamfer'] > 0:
|
if self.cfg["train"]["w_chamfer"] > 0:
|
||||||
loss_ = self.cfg['train']['w_chamfer'] * \
|
loss_ = self.cfg["train"]["w_chamfer"] * self.compute_3d_loss(v, data)
|
||||||
self.compute_3d_loss(v, data)
|
loss_each["chamfer"] = loss_
|
||||||
loss_each['chamfer'] = loss_
|
|
||||||
loss += loss_
|
loss += loss_
|
||||||
elif self.data_type == 'img':
|
elif self.data_type == "img":
|
||||||
loss, loss_each = self.compute_2d_loss(inputs, data, model)
|
loss, loss_each = self.compute_2d_loss(inputs, data, model)
|
||||||
|
|
||||||
return loss, loss_each
|
return loss, loss_each
|
||||||
|
|
||||||
|
|
||||||
def pcl2psr(self, inputs):
|
def pcl2psr(self, inputs):
|
||||||
''' Convert an oriented point cloud to PSR indicator grid
|
"""Convert an oriented point cloud to PSR indicator grid
|
||||||
Args:
|
Args:
|
||||||
inputs (torch.tensor): input oriented point clouds
|
inputs (torch.tensor): input oriented point clouds.
|
||||||
'''
|
"""
|
||||||
|
points, normals = inputs[..., :3], inputs[..., 3:]
|
||||||
points, normals = inputs[...,:3], inputs[...,3:]
|
if self.cfg["model"]["apply_sigmoid"]:
|
||||||
if self.cfg['model']['apply_sigmoid']:
|
|
||||||
points = torch.sigmoid(points)
|
points = torch.sigmoid(points)
|
||||||
if self.cfg['model']['normal_normalize']:
|
if self.cfg["model"]["normal_normalize"]:
|
||||||
normals = normals / normals.norm(dim=-1, keepdim=True)
|
normals = normals / normals.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
# DPSR to get grid
|
# DPSR to get grid
|
||||||
|
@ -116,17 +114,17 @@ class Trainer(object):
|
||||||
return psr_grid, points, normals
|
return psr_grid, points, normals
|
||||||
|
|
||||||
def compute_3d_loss(self, v, data):
|
def compute_3d_loss(self, v, data):
|
||||||
''' Compute the loss for point clouds.
|
"""Compute the loss for point clouds.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
v (torch.tensor) : mesh vertices
|
v (torch.tensor) : mesh vertices
|
||||||
data (dict) : data dictionary
|
data (dict) : data dictionary.
|
||||||
'''
|
"""
|
||||||
|
pts_gt = data.get("target_points")
|
||||||
pts_gt = data.get('target_points')
|
idx = np.random.randint(pts_gt.shape[1], size=self.cfg["train"]["n_sup_point"])
|
||||||
idx = np.random.randint(pts_gt.shape[1], size=self.cfg['train']['n_sup_point'])
|
if self.cfg["train"]["subsample_vertex"]:
|
||||||
if self.cfg['train']['subsample_vertex']:
|
# chamfer distance only on random sampled vertices
|
||||||
#chamfer distance only on random sampled vertices
|
idx = np.random.randint(v.shape[1], size=self.cfg["train"]["n_sup_point"])
|
||||||
idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point'])
|
|
||||||
loss, _ = chamfer_distance(v[:, idx], pts_gt)
|
loss, _ = chamfer_distance(v[:, idx], pts_gt)
|
||||||
else:
|
else:
|
||||||
loss, _ = chamfer_distance(v, pts_gt)
|
loss, _ = chamfer_distance(v, pts_gt)
|
||||||
|
@ -134,34 +132,31 @@ class Trainer(object):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def compute_2d_loss(self, inputs, data, model):
|
def compute_2d_loss(self, inputs, data, model):
|
||||||
''' Compute the 2D losses.
|
"""Compute the 2D losses.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (torch.tensor) : input source point clouds
|
inputs (torch.tensor) : input source point clouds
|
||||||
data (dict) : data dictionary
|
data (dict) : data dictionary
|
||||||
model (nn.Module or None): neural network or None
|
model (nn.Module or None): neural network or None.
|
||||||
'''
|
"""
|
||||||
|
losses = {
|
||||||
losses = {"color":
|
"color": {
|
||||||
{"weight": self.cfg['train']['l_weight']['rgb'],
|
"weight": self.cfg["train"]["l_weight"]["rgb"],
|
||||||
"values": []
|
"values": [],
|
||||||
},
|
},
|
||||||
"silhouette":
|
"silhouette": {"weight": self.cfg["train"]["l_weight"]["mask"], "values": []},
|
||||||
{"weight": self.cfg['train']['l_weight']['mask'],
|
|
||||||
"values": []},
|
|
||||||
}
|
}
|
||||||
loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses}
|
loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses}
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
out = model(inputs, data)
|
out = model(inputs, data)
|
||||||
|
|
||||||
if out['rgb'] is not None:
|
if out["rgb"] is not None:
|
||||||
rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'],
|
rgb_gt = out["rgb_gt"].reshape(self.cfg["data"]["n_views_per_iter"], -1, 3)[out["vis_mask"]]
|
||||||
-1, 3)[out['vis_mask']]
|
loss_all["color"] += torch.nn.L1Loss(reduction="sum")(rgb_gt, out["rgb"]) / out["rgb"].shape[0]
|
||||||
loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt,
|
|
||||||
out['rgb']) / out['rgb'].shape[0]
|
|
||||||
|
|
||||||
if out['mask'] is not None:
|
if out["mask"] is not None:
|
||||||
loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean()
|
loss_all["silhouette"] += ((out["mask"] - out["mask_gt"]) ** 2).mean()
|
||||||
|
|
||||||
# weighted sum of the losses
|
# weighted sum of the losses
|
||||||
loss = torch.tensor(0.0, device=self.device)
|
loss = torch.tensor(0.0, device=self.device)
|
||||||
|
@ -172,15 +167,14 @@ class Trainer(object):
|
||||||
return loss, loss_all
|
return loss, loss_all
|
||||||
|
|
||||||
def point_resampling(self, inputs):
|
def point_resampling(self, inputs):
|
||||||
''' Resample points
|
"""Resample points
|
||||||
Args:
|
Args:
|
||||||
inputs (torch.tensor): oriented point clouds
|
inputs (torch.tensor): oriented point clouds.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
psr_grid, points, normals = self.pcl2psr(inputs)
|
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||||
|
|
||||||
# shortcuts
|
# shortcuts
|
||||||
n_grow = self.cfg['train']['n_grow_points']
|
n_grow = self.cfg["train"]["n_grow_points"]
|
||||||
|
|
||||||
# [hack] for points resampled from the mesh from marching cubes,
|
# [hack] for points resampled from the mesh from marching cubes,
|
||||||
# we need to divide by s instead of (s-1), and the scale is correct.
|
# we need to divide by s instead of (s-1), and the scale is correct.
|
||||||
|
@ -191,13 +185,13 @@ class Trainer(object):
|
||||||
|
|
||||||
# sample vertices only from the largest component, not from fragments
|
# sample vertices only from the largest component, not from fragments
|
||||||
mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh)
|
mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh)
|
||||||
pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True)
|
pi, face_idx = mesh.sample(n_grow + points.shape[1], return_index=True)
|
||||||
normals_i = mesh.face_normals[face_idx].astype('float32')
|
normals_i = mesh.face_normals[face_idx].astype("float32")
|
||||||
pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None]
|
pts_mesh = torch.tensor(pi.astype("float32")).to(self.device)[None]
|
||||||
n_mesh = torch.tensor(normals_i).to(self.device)[None]
|
n_mesh = torch.tensor(normals_i).to(self.device)[None]
|
||||||
|
|
||||||
points, normals = pts_mesh, n_mesh
|
points, normals = pts_mesh, n_mesh
|
||||||
print('{} total points are resampled'.format(points.shape[1]))
|
print(f"{points.shape[1]} total points are resampled")
|
||||||
|
|
||||||
# update inputs
|
# update inputs
|
||||||
points = torch.log(points / (1 - points)) # inverse sigmoid
|
points = torch.log(points / (1 - points)) # inverse sigmoid
|
||||||
|
@ -207,119 +201,112 @@ class Trainer(object):
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def visualize(self, data, inputs, renderer, epoch, o3d_vis=None):
|
def visualize(self, data, inputs, renderer, epoch, o3d_vis=None):
|
||||||
''' Visualization.
|
"""Visualization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dict) : data dictionary
|
data (dict) : data dictionary
|
||||||
inputs (torch.tensor) : source point clouds
|
inputs (torch.tensor) : source point clouds
|
||||||
renderer (nn.Module or None): a neural network or None
|
renderer (nn.Module or None): a neural network or None
|
||||||
epoch (int) : the number of iterations
|
epoch (int) : the number of iterations
|
||||||
o3d_vis (o3d.Visualizer) : open3d visualizer
|
o3d_vis (o3d.Visualizer) : open3d visualizer.
|
||||||
'''
|
"""
|
||||||
|
data_type = self.cfg["data"]["data_type"]
|
||||||
|
it = "{:04d}".format(int(epoch / self.cfg["train"]["visualize_every"]))
|
||||||
|
|
||||||
data_type = self.cfg['data']['data_type']
|
if (self.cfg["train"]["exp_mesh"]) | (self.cfg["train"]["exp_pcl"]) | (self.cfg["train"]["o3d_show"]):
|
||||||
it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every']))
|
|
||||||
|
|
||||||
|
|
||||||
if (self.cfg['train']['exp_mesh']) \
|
|
||||||
| (self.cfg['train']['exp_pcl']) \
|
|
||||||
| (self.cfg['train']['o3d_show']):
|
|
||||||
psr_grid, points, normals = self.pcl2psr(inputs)
|
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
v, f, n = mc_from_psr(psr_grid, pytorchify=True,
|
v, f, n = mc_from_psr(
|
||||||
zero_level=self.cfg['data']['zero_level'], real_scale=True)
|
psr_grid, pytorchify=True, zero_level=self.cfg["data"]["zero_level"], real_scale=True,
|
||||||
|
)
|
||||||
v, f, n = v[None], f[None], n[None]
|
v, f, n = v[None], f[None], n[None]
|
||||||
|
|
||||||
v = v * 2. - 1. # change to the range of [-1, 1]
|
v = v * 2.0 - 1.0 # change to the range of [-1, 1]
|
||||||
|
|
||||||
color_v = None
|
color_v = None
|
||||||
if data_type == 'img':
|
if data_type == "img":
|
||||||
if self.cfg['train']['vis_vert_color'] & \
|
if self.cfg["train"]["vis_vert_color"] & (self.cfg["train"]["l_weight"]["rgb"] != 0.0):
|
||||||
(self.cfg['train']['l_weight']['rgb'] != 0.):
|
color_v = renderer["color"](v, n).squeeze().detach().cpu().numpy()
|
||||||
color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy()
|
color_v[color_v < 0], color_v[color_v > 1] = 0.0, 1.0
|
||||||
color_v[color_v<0], color_v[color_v>1] = 0., 1.
|
|
||||||
|
|
||||||
vv = v.detach().squeeze().cpu().numpy()
|
vv = v.detach().squeeze().cpu().numpy()
|
||||||
ff = f.detach().squeeze().cpu().numpy()
|
ff = f.detach().squeeze().cpu().numpy()
|
||||||
points = points * 2 - 1
|
points = points * 2 - 1
|
||||||
visualize_points_mesh(o3d_vis, points, normals,
|
visualize_points_mesh(o3d_vis, points, normals, vv, ff, self.cfg, it, epoch, color_v=color_v)
|
||||||
vv, ff, self.cfg, it, epoch, color_v=color_v)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
v, f, n = inputs
|
v, f, n = inputs
|
||||||
|
|
||||||
|
if (data_type == "img") & (self.cfg["train"]["vis_rendering"]):
|
||||||
if (data_type == 'img') & (self.cfg['train']['vis_rendering']):
|
|
||||||
pred_imgs = []
|
pred_imgs = []
|
||||||
pred_masks = []
|
len(data["poses"])
|
||||||
n_views = len(data['poses'])
|
|
||||||
# idx_list = trange(n_views)
|
# idx_list = trange(n_views)
|
||||||
idx_list = [13, 24, 27, 48]
|
idx_list = [13, 24, 27, 48]
|
||||||
|
|
||||||
#!
|
#!
|
||||||
model = renderer.eval()
|
model = renderer.eval()
|
||||||
for idx in idx_list:
|
for idx in idx_list:
|
||||||
pose = data['poses'][idx]
|
pose = data["poses"][idx]
|
||||||
rgb = data['rgbs'][idx]
|
rgb = data["rgbs"][idx]
|
||||||
mask_gt = data['masks'][idx]
|
data["masks"][idx]
|
||||||
img_size = rgb.shape[0] if rgb.shape[0]== rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
|
img_size = rgb.shape[0] if rgb.shape[0] == rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
|
||||||
ray = None
|
ray = None
|
||||||
if 'rays' in data.keys():
|
if "rays" in data.keys():
|
||||||
ray = data['rays'][idx]
|
ray = data["rays"][idx]
|
||||||
if self.cfg['train']['l_weight']['rgb'] != 0.:
|
if self.cfg["train"]["l_weight"]["rgb"] != 0.0:
|
||||||
fea_grid = None
|
fea_grid = None
|
||||||
if model.unet3d is not None:
|
if model.unet3d is not None:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1)
|
fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1)
|
||||||
if model.encoder is not None:
|
if model.encoder is not None:
|
||||||
pp = torch.cat([(points+1)/2, normals], dim=-1)
|
pp = torch.cat([(points + 1) / 2, normals], dim=-1)
|
||||||
fea_grid = model.encoder(pp,
|
fea_grid = model.encoder(pp, normalize=False).permute(0, 2, 3, 4, 1)
|
||||||
normalize=False).permute(0, 2, 3, 4, 1)
|
|
||||||
|
|
||||||
pred, visible_mask = render_rgb(v, f, n, pose,
|
pred, visible_mask = render_rgb(
|
||||||
model.rendering_network.eval(),
|
v, f, n, pose, model.rendering_network.eval(), img_size, ray=ray, fea_grid=fea_grid,
|
||||||
img_size, ray=ray, fea_grid=fea_grid)
|
)
|
||||||
img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3])
|
img_pred = torch.zeros([rgb.shape[0] * rgb.shape[1], 3])
|
||||||
img_pred[visible_mask] = pred.detach().cpu()
|
img_pred[visible_mask] = pred.detach().cpu()
|
||||||
|
|
||||||
img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3)
|
img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3)
|
||||||
img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1.
|
img_pred[img_pred < 0], img_pred[img_pred > 1] = 0.0, 1.0
|
||||||
filename=os.path.join(self.cfg['train']['dir_rendering'],
|
filename = os.path.join(self.cfg["train"]["dir_rendering"], f"rendering_{it}_{idx:d}.png")
|
||||||
'rendering_{}_{:d}.png'.format(it, idx))
|
|
||||||
save_image(img_pred.permute(2, 0, 1), filename)
|
save_image(img_pred.permute(2, 0, 1), filename)
|
||||||
pred_imgs.append(img_pred)
|
pred_imgs.append(img_pred)
|
||||||
|
|
||||||
#! Mesh rendering using Phong shading model
|
#! Mesh rendering using Phong shading model
|
||||||
filename=os.path.join(self.cfg['train']['dir_rendering'],
|
filename = os.path.join(self.cfg["train"]["dir_rendering"], f"mesh_{it}_{idx:d}.png")
|
||||||
'mesh_{}_{:d}.png'.format(it, idx))
|
|
||||||
visualize_mesh_phong(v, f, n, pose, img_size, name=filename)
|
visualize_mesh_phong(v, f, n, pose, img_size, name=filename)
|
||||||
|
|
||||||
if len(pred_imgs) >= 1:
|
if len(pred_imgs) >= 1:
|
||||||
pred_imgs = torch.stack(pred_imgs, dim=0)
|
pred_imgs = torch.stack(pred_imgs, dim=0)
|
||||||
save_image(pred_imgs.permute(0, 3, 1, 2),
|
save_image(
|
||||||
os.path.join(self.cfg['train']['dir_rendering'],
|
pred_imgs.permute(0, 3, 1, 2), os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.png"), nrow=4,
|
||||||
'{}.png'.format(it)), nrow=4)
|
)
|
||||||
if self.cfg['train']['save_video']:
|
if self.cfg["train"]["save_video"]:
|
||||||
write_video(os.path.join(self.cfg['train']['dir_rendering'],
|
write_video(
|
||||||
'{}.mp4'.format(it)),
|
os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.mp4"),
|
||||||
(pred_imgs*255.).type(torch.uint8), fps=24)
|
(pred_imgs * 255.0).type(torch.uint8),
|
||||||
|
fps=24,
|
||||||
|
)
|
||||||
|
|
||||||
def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None):
|
def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None):
|
||||||
''' Save meshes and point clouds.
|
"""Save meshes and point clouds.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (torch.tensor) : source point clouds
|
inputs (torch.tensor) : source point clouds
|
||||||
epoch (int) : the number of iterations
|
epoch (int) : the number of iterations
|
||||||
center (numpy.array) : center of the shape
|
center (numpy.array) : center of the shape
|
||||||
scale (numpy.array) : scale of the shape
|
scale (numpy.array) : scale of the shape.
|
||||||
'''
|
"""
|
||||||
|
exp_pcl = self.cfg["train"]["exp_pcl"]
|
||||||
exp_pcl = self.cfg['train']['exp_pcl']
|
exp_mesh = self.cfg["train"]["exp_mesh"]
|
||||||
exp_mesh = self.cfg['train']['exp_mesh']
|
|
||||||
|
|
||||||
psr_grid, points, normals = self.pcl2psr(inputs)
|
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||||
|
|
||||||
if exp_pcl:
|
if exp_pcl:
|
||||||
dir_pcl = self.cfg['train']['dir_pcl']
|
dir_pcl = self.cfg["train"]["dir_pcl"]
|
||||||
p = points.squeeze(0).detach().cpu().numpy()
|
p = points.squeeze(0).detach().cpu().numpy()
|
||||||
p = p * 2 - 1
|
p = p * 2 - 1
|
||||||
n = normals.squeeze(0).detach().cpu().numpy()
|
n = normals.squeeze(0).detach().cpu().numpy()
|
||||||
|
@ -327,12 +314,11 @@ class Trainer(object):
|
||||||
p *= scale
|
p *= scale
|
||||||
if center is not None:
|
if center is not None:
|
||||||
p += center
|
p += center
|
||||||
export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n)
|
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}.ply"), p, n)
|
||||||
if exp_mesh:
|
if exp_mesh:
|
||||||
dir_mesh = self.cfg['train']['dir_mesh']
|
dir_mesh = self.cfg["train"]["dir_mesh"]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
v, f, _ = mc_from_psr(psr_grid,
|
v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"], real_scale=True)
|
||||||
zero_level=self.cfg['data']['zero_level'], real_scale=True)
|
|
||||||
v = v * 2 - 1
|
v = v * 2 - 1
|
||||||
if scale is not None:
|
if scale is not None:
|
||||||
v *= scale
|
v *= scale
|
||||||
|
@ -341,9 +327,9 @@ class Trainer(object):
|
||||||
mesh = o3d.geometry.TriangleMesh()
|
mesh = o3d.geometry.TriangleMesh()
|
||||||
mesh.vertices = o3d.utility.Vector3dVector(v)
|
mesh.vertices = o3d.utility.Vector3dVector(v)
|
||||||
mesh.triangles = o3d.utility.Vector3iVector(f)
|
mesh.triangles = o3d.utility.Vector3iVector(f)
|
||||||
outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch))
|
outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}.ply")
|
||||||
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
|
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
|
||||||
|
|
||||||
if self.cfg['train']['vis_psr']:
|
if self.cfg["train"]["vis_psr"]:
|
||||||
dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all'
|
dir_psr_vis = self.cfg["train"]["out_dir"] + "/psr_vis_all"
|
||||||
visualize_psr_grid(psr_grid, out_dir=dir_psr_vis)
|
visualize_psr_grid(psr_grid, out_dir=dir_psr_vis)
|
||||||
|
|
153
src/training.py
153
src/training.py
|
@ -1,55 +1,61 @@
|
||||||
import os
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.loss import chamfer_distance
|
||||||
|
from pytorch3d.ops.knn import knn_gather, knn_points
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from collections import defaultdict
|
|
||||||
import trimesh
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from src.dpsr import DPSR
|
from src.dpsr import DPSR
|
||||||
from src.utils import grid_interp, export_pointcloud, export_mesh, \
|
from src.utils import (
|
||||||
mc_from_psr, scale2onet, GaussianSmoothing
|
GaussianSmoothing,
|
||||||
from pytorch3d.ops.knn import knn_gather, knn_points
|
export_mesh,
|
||||||
from pytorch3d.loss import chamfer_distance
|
export_pointcloud,
|
||||||
from pdb import set_trace as st
|
mc_from_psr,
|
||||||
|
scale2onet,
|
||||||
|
)
|
||||||
|
|
||||||
class Trainer(object):
|
|
||||||
'''
|
class Trainer:
|
||||||
Args:
|
"""Args:
|
||||||
model (nn.Module): our defined model
|
model (nn.Module): our defined model
|
||||||
optimizer (optimizer): pytorch optimizer object
|
optimizer (optimizer): pytorch optimizer object
|
||||||
device (device): pytorch device
|
device (device): pytorch device
|
||||||
input_type (str): input type
|
input_type (str): input type
|
||||||
vis_dir (str): visualization directory
|
vis_dir (str): visualization directory.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg, optimizer, device=None):
|
def __init__(self, cfg, optimizer, device=None):
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.device = device
|
self.device = device
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
if self.cfg['train']['w_raw'] != 0:
|
if self.cfg["train"]["w_raw"] != 0:
|
||||||
from src.model import PSR2Mesh
|
from src.model import PSR2Mesh
|
||||||
|
|
||||||
self.psr2mesh = PSR2Mesh.apply
|
self.psr2mesh = PSR2Mesh.apply
|
||||||
|
|
||||||
# initialize DPSR
|
# initialize DPSR
|
||||||
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
self.dpsr = DPSR(
|
||||||
cfg['model']['grid_res'],
|
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
|
||||||
cfg['model']['grid_res']),
|
sig=cfg["model"]["psr_sigma"],
|
||||||
sig=cfg['model']['psr_sigma'])
|
)
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR
|
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR
|
||||||
self.dpsr = self.dpsr.to(device)
|
self.dpsr = self.dpsr.to(device)
|
||||||
|
|
||||||
if cfg['train']['gauss_weight']>0.:
|
if cfg["train"]["gauss_weight"] > 0.0:
|
||||||
self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device)
|
self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device)
|
||||||
|
|
||||||
def train_step(self, inputs, data, model):
|
def train_step(self, inputs, data, model):
|
||||||
''' Performs a training step.
|
"""Performs a training step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dict): data dictionary
|
data (dict): data dictionary
|
||||||
'''
|
"""
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
p = data.get('inputs').to(self.device)
|
p = data.get("inputs").to(self.device)
|
||||||
|
|
||||||
out = model(p)
|
out = model(p)
|
||||||
|
|
||||||
|
@ -57,18 +63,18 @@ class Trainer(object):
|
||||||
|
|
||||||
loss = 0
|
loss = 0
|
||||||
loss_each = {}
|
loss_each = {}
|
||||||
if self.cfg['train']['w_psr'] != 0:
|
if self.cfg["train"]["w_psr"] != 0:
|
||||||
psr_gt = data.get('gt_psr').to(self.device)
|
psr_gt = data.get("gt_psr").to(self.device)
|
||||||
if self.cfg['model']['psr_tanh']:
|
if self.cfg["model"]["psr_tanh"]:
|
||||||
psr_gt = torch.tanh(psr_gt)
|
psr_gt = torch.tanh(psr_gt)
|
||||||
|
|
||||||
psr_grid = self.dpsr(points, normals)
|
psr_grid = self.dpsr(points, normals)
|
||||||
if self.cfg['model']['psr_tanh']:
|
if self.cfg["model"]["psr_tanh"]:
|
||||||
psr_grid = torch.tanh(psr_grid)
|
psr_grid = torch.tanh(psr_grid)
|
||||||
|
|
||||||
# apply a rescaling weight based on GT SDF values
|
# apply a rescaling weight based on GT SDF values
|
||||||
if self.cfg['train']['gauss_weight']>0:
|
if self.cfg["train"]["gauss_weight"] > 0:
|
||||||
gauss_sigma = self.cfg['train']['gauss_weight']
|
self.cfg["train"]["gauss_weight"]
|
||||||
# set up the weighting for loss, higher weights
|
# set up the weighting for loss, higher weights
|
||||||
# for points near to the surface
|
# for points near to the surface
|
||||||
psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1)
|
psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1)
|
||||||
|
@ -82,43 +88,43 @@ class Trainer(object):
|
||||||
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1)
|
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1)
|
||||||
psr_grad_norm = psr_grad.norm(dim=-1)[:, None]
|
psr_grad_norm = psr_grad.norm(dim=-1)[:, None]
|
||||||
w = torch.nn.ReplicationPad3d(3)(psr_grad_norm)
|
w = torch.nn.ReplicationPad3d(3)(psr_grad_norm)
|
||||||
w = 2*self.gauss_smooth(w).squeeze(1)
|
w = 2 * self.gauss_smooth(w).squeeze(1)
|
||||||
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt)
|
loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(w * psr_grid, w * psr_gt)
|
||||||
else:
|
else:
|
||||||
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt)
|
loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(psr_grid, psr_gt)
|
||||||
|
|
||||||
loss += loss_each['psr']
|
loss += loss_each["psr"]
|
||||||
|
|
||||||
# regularization on the input point positions via chamfer distance
|
# regularization on the input point positions via chamfer distance
|
||||||
if self.cfg['train']['w_reg_point'] != 0.:
|
if self.cfg["train"]["w_reg_point"] != 0.0:
|
||||||
points_gt = data.get('gt_points').to(self.device)
|
points_gt = data.get("gt_points").to(self.device)
|
||||||
loss_reg, loss_norm = chamfer_distance(points, points_gt)
|
loss_reg, loss_norm = chamfer_distance(points, points_gt)
|
||||||
|
|
||||||
loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg
|
loss_each["reg"] = self.cfg["train"]["w_reg_point"] * loss_reg
|
||||||
loss += loss_each['reg']
|
loss += loss_each["reg"]
|
||||||
|
|
||||||
if self.cfg['train']['w_normals'] != 0.:
|
if self.cfg["train"]["w_normals"] != 0.0:
|
||||||
points_gt = data.get('gt_points').to(self.device)
|
points_gt = data.get("gt_points").to(self.device)
|
||||||
normals_gt = data.get('gt_points.normals').to(self.device)
|
normals_gt = data.get("gt_points.normals").to(self.device)
|
||||||
x_nn = knn_points(points, points_gt, K=1)
|
x_nn = knn_points(points, points_gt, K=1)
|
||||||
x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :]
|
x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :]
|
||||||
|
|
||||||
cham_norm_x = F.l1_loss(normals, x_normals_near)
|
cham_norm_x = F.l1_loss(normals, x_normals_near)
|
||||||
loss_norm = cham_norm_x
|
loss_norm = cham_norm_x
|
||||||
|
|
||||||
loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm
|
loss_each["normals"] = self.cfg["train"]["w_normals"] * loss_norm
|
||||||
loss += loss_each['normals']
|
loss += loss_each["normals"]
|
||||||
|
|
||||||
if self.cfg['train']['w_raw'] != 0:
|
if self.cfg["train"]["w_raw"] != 0:
|
||||||
res = self.cfg['model']['grid_res']
|
self.cfg["model"]["grid_res"]
|
||||||
# DPSR to get grid
|
# DPSR to get grid
|
||||||
psr_grid = self.dpsr(points, normals)
|
psr_grid = self.dpsr(points, normals)
|
||||||
if self.cfg['model']['psr_tanh']:
|
if self.cfg["model"]["psr_tanh"]:
|
||||||
psr_grid = torch.tanh(psr_grid)
|
psr_grid = torch.tanh(psr_grid)
|
||||||
|
|
||||||
v, f, n = self.psr2mesh(psr_grid)
|
v, f, n = self.psr2mesh(psr_grid)
|
||||||
|
|
||||||
pts_gt = data.get('gt_points').to(self.device)
|
pts_gt = data.get("gt_points").to(self.device)
|
||||||
|
|
||||||
loss, _ = chamfer_distance(v, pts_gt)
|
loss, _ = chamfer_distance(v, pts_gt)
|
||||||
|
|
||||||
|
@ -128,51 +134,51 @@ class Trainer(object):
|
||||||
return loss.item(), loss_each
|
return loss.item(), loss_each
|
||||||
|
|
||||||
def save(self, model, data, epoch, id):
|
def save(self, model, data, epoch, id):
|
||||||
|
p = data.get("inputs").to(self.device)
|
||||||
|
|
||||||
p = data.get('inputs').to(self.device)
|
exp_pcl = self.cfg["train"]["exp_pcl"]
|
||||||
|
exp_mesh = self.cfg["train"]["exp_mesh"]
|
||||||
exp_pcl = self.cfg['train']['exp_pcl']
|
exp_gt = self.cfg["generation"]["exp_gt"]
|
||||||
exp_mesh = self.cfg['train']['exp_mesh']
|
exp_input = self.cfg["generation"]["exp_input"]
|
||||||
exp_gt = self.cfg['generation']['exp_gt']
|
|
||||||
exp_input = self.cfg['generation']['exp_input']
|
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
points, normals = model(p)
|
points, normals = model(p)
|
||||||
|
|
||||||
if exp_gt:
|
if exp_gt:
|
||||||
points_gt = data.get('gt_points').to(self.device)
|
points_gt = data.get("gt_points").to(self.device)
|
||||||
normals_gt = data.get('gt_points.normals').to(self.device)
|
normals_gt = data.get("gt_points.normals").to(self.device)
|
||||||
|
|
||||||
if exp_pcl:
|
if exp_pcl:
|
||||||
dir_pcl = self.cfg['train']['dir_pcl']
|
dir_pcl = self.cfg["train"]["dir_pcl"]
|
||||||
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals)
|
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}.ply"), scale2onet(points), normals)
|
||||||
if exp_gt:
|
if exp_gt:
|
||||||
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt)
|
export_pointcloud(
|
||||||
|
os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(points_gt), normals_gt,
|
||||||
|
)
|
||||||
if exp_input:
|
if exp_input:
|
||||||
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p))
|
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_input.ply"), scale2onet(p))
|
||||||
|
|
||||||
if exp_mesh:
|
if exp_mesh:
|
||||||
dir_mesh = self.cfg['train']['dir_mesh']
|
dir_mesh = self.cfg["train"]["dir_mesh"]
|
||||||
psr_grid = self.dpsr(points, normals)
|
psr_grid = self.dpsr(points, normals)
|
||||||
# psr_grid = torch.tanh(psr_grid)
|
# psr_grid = torch.tanh(psr_grid)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
v, f, _ = mc_from_psr(psr_grid,
|
v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"])
|
||||||
zero_level=self.cfg['data']['zero_level'])
|
outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}.ply")
|
||||||
outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id))
|
|
||||||
export_mesh(outdir_mesh, scale2onet(v), f)
|
export_mesh(outdir_mesh, scale2onet(v), f)
|
||||||
if exp_gt:
|
if exp_gt:
|
||||||
psr_gt = self.dpsr(points_gt, normals_gt)
|
psr_gt = self.dpsr(points_gt, normals_gt)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
v, f, _ = mc_from_psr(psr_gt,
|
v, f, _ = mc_from_psr(psr_gt, zero_level=self.cfg["data"]["zero_level"])
|
||||||
zero_level=self.cfg['data']['zero_level'])
|
export_mesh(os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(v), f)
|
||||||
export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f)
|
|
||||||
|
|
||||||
def evaluate(self, val_loader, model):
|
def evaluate(self, val_loader, model):
|
||||||
''' Performs an evaluation.
|
"""Performs an evaluation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
val_loader (dataloader): pytorch dataloader
|
val_loader (dataloader): pytorch dataloader.
|
||||||
'''
|
"""
|
||||||
eval_list = defaultdict(list)
|
eval_list = defaultdict(list)
|
||||||
|
|
||||||
for data in tqdm(val_loader):
|
for data in tqdm(val_loader):
|
||||||
|
@ -185,15 +191,16 @@ class Trainer(object):
|
||||||
return eval_dict
|
return eval_dict
|
||||||
|
|
||||||
def eval_step(self, data, model):
|
def eval_step(self, data, model):
|
||||||
''' Performs an evaluation step.
|
"""Performs an evaluation step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dict): data dictionary
|
data (dict): data dictionary.
|
||||||
'''
|
"""
|
||||||
model.eval()
|
model.eval()
|
||||||
eval_dict = {}
|
eval_dict = {}
|
||||||
|
|
||||||
p = data.get('inputs').to(self.device)
|
p = data.get("inputs").to(self.device)
|
||||||
psr_gt = data.get('gt_psr').to(self.device)
|
psr_gt = data.get("gt_psr").to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# forward pass
|
# forward pass
|
||||||
|
@ -201,7 +208,7 @@ class Trainer(object):
|
||||||
# DPSR to get predicted psr grid
|
# DPSR to get predicted psr grid
|
||||||
psr_grid = self.dpsr(points, normals)
|
psr_grid = self.dpsr(points, normals)
|
||||||
|
|
||||||
eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item()
|
eval_dict["psr_l1"] = F.l1_loss(psr_grid, psr_gt).item()
|
||||||
eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item()
|
eval_dict["psr_l2"] = F.mse_loss(psr_grid, psr_gt).item()
|
||||||
|
|
||||||
return eval_dict
|
return eval_dict
|
335
src/utils.py
335
src/utils.py
|
@ -1,53 +1,53 @@
|
||||||
import torch
|
import logging
|
||||||
import io, os, logging, urllib
|
|
||||||
import yaml
|
|
||||||
import trimesh
|
|
||||||
import imageio
|
|
||||||
import numbers
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numbers
|
||||||
|
import os
|
||||||
|
import urllib
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import open3d as o3d
|
||||||
|
import torch
|
||||||
|
import trimesh
|
||||||
|
import yaml
|
||||||
|
from igl import adjacency_matrix, connected_components
|
||||||
from plyfile import PlyData
|
from plyfile import PlyData
|
||||||
|
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
from skimage import measure
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.utils import model_zoo
|
from torch.utils import model_zoo
|
||||||
from skimage import measure, img_as_float32
|
|
||||||
from pytorch3d.structures import Meshes
|
|
||||||
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
|
|
||||||
from igl import adjacency_matrix, connected_components
|
|
||||||
import open3d as o3d
|
|
||||||
|
|
||||||
##################################################
|
##################################################
|
||||||
# Below are functions for DPSR
|
# Below are functions for DPSR
|
||||||
|
|
||||||
|
|
||||||
def fftfreqs(res, dtype=torch.float32, exact=True):
|
def fftfreqs(res, dtype=torch.float32, exact=True):
|
||||||
"""
|
"""Helper function to return frequency tensors
|
||||||
Helper function to return frequency tensors
|
|
||||||
:param res: n_dims int tuple of number of frequency modes
|
:param res: n_dims int tuple of number of frequency modes
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n_dims = len(res)
|
n_dims = len(res)
|
||||||
freqs = []
|
freqs = []
|
||||||
for dim in range(n_dims - 1):
|
for dim in range(n_dims - 1):
|
||||||
r_ = res[dim]
|
r_ = res[dim]
|
||||||
freq = np.fft.fftfreq(r_, d=1/r_)
|
freq = np.fft.fftfreq(r_, d=1 / r_)
|
||||||
freqs.append(torch.tensor(freq, dtype=dtype))
|
freqs.append(torch.tensor(freq, dtype=dtype))
|
||||||
r_ = res[-1]
|
r_ = res[-1]
|
||||||
if exact:
|
if exact:
|
||||||
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
|
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_), dtype=dtype))
|
||||||
else:
|
else:
|
||||||
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
|
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_)[:-1], dtype=dtype))
|
||||||
omega = torch.meshgrid(freqs)
|
omega = torch.meshgrid(freqs)
|
||||||
omega = list(omega)
|
omega = list(omega)
|
||||||
omega = torch.stack(omega, dim=-1)
|
omega = torch.stack(omega, dim=-1)
|
||||||
|
|
||||||
return omega
|
return omega
|
||||||
|
|
||||||
|
|
||||||
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
|
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
|
||||||
"""
|
"""Multiply tensor x by i ** deg."""
|
||||||
multiply tensor x by i ** deg
|
|
||||||
"""
|
|
||||||
deg %= 4
|
deg %= 4
|
||||||
if deg == 0:
|
if deg == 0:
|
||||||
res = x
|
res = x
|
||||||
|
@ -61,17 +61,18 @@ def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
|
||||||
res[..., 1] = -res[..., 1]
|
res[..., 1] = -res[..., 1]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def spec_gaussian_filter(res, sig):
|
def spec_gaussian_filter(res, sig):
|
||||||
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
|
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
|
||||||
dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
|
dis = torch.sqrt(torch.sum(omega**2, dim=-1))
|
||||||
filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
|
filter_ = torch.exp(-0.5 * ((sig * 2 * dis / res[0]) ** 2)).unsqueeze(-1).unsqueeze(-1)
|
||||||
filter_.requires_grad = False
|
filter_.requires_grad = False
|
||||||
|
|
||||||
return filter_
|
return filter_
|
||||||
|
|
||||||
|
|
||||||
def grid_interp(grid, pts, batched=True):
|
def grid_interp(grid, pts, batched=True):
|
||||||
"""
|
""":param grid: tensor of shape (batch, *size, in_features)
|
||||||
:param grid: tensor of shape (batch, *size, in_features)
|
|
||||||
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
|
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
|
||||||
:return values at query points
|
:return values at query points
|
||||||
"""
|
"""
|
||||||
|
@ -86,7 +87,7 @@ def grid_interp(grid, pts, batched=True):
|
||||||
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
||||||
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
||||||
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
||||||
tmp = torch.tensor([0,1],dtype=torch.long)
|
tmp = torch.tensor([0, 1], dtype=torch.long)
|
||||||
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
||||||
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
||||||
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
||||||
|
@ -96,14 +97,16 @@ def grid_interp(grid, pts, batched=True):
|
||||||
if dim == 2:
|
if dim == 2:
|
||||||
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
|
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
|
||||||
else:
|
else:
|
||||||
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features)
|
lat = grid.clone()[
|
||||||
|
ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2],
|
||||||
|
] # (batch, num_points, 2**dim, in_features)
|
||||||
|
|
||||||
# weights of neighboring nodes
|
# weights of neighboring nodes
|
||||||
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
||||||
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
||||||
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
||||||
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||||
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||||
pos_ = pos_.type(pts.dtype)
|
pos_ = pos_.type(pts.dtype)
|
||||||
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
||||||
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
||||||
|
@ -113,38 +116,38 @@ def grid_interp(grid, pts, batched=True):
|
||||||
|
|
||||||
return query_values
|
return query_values
|
||||||
|
|
||||||
|
|
||||||
def scatter_to_grid(inds, vals, size):
|
def scatter_to_grid(inds, vals, size):
|
||||||
"""
|
"""Scatter update values into empty tensor of size size.
|
||||||
Scatter update values into empty tensor of size size.
|
|
||||||
:param inds: (#values, dims)
|
:param inds: (#values, dims)
|
||||||
:param vals: (#values)
|
:param vals: (#values)
|
||||||
:param size: tuple for size. len(size)=dims
|
:param size: tuple for size. len(size)=dims.
|
||||||
"""
|
"""
|
||||||
dims = inds.shape[1]
|
dims = inds.shape[1]
|
||||||
assert(inds.shape[0] == vals.shape[0])
|
assert inds.shape[0] == vals.shape[0]
|
||||||
assert(len(size) == dims)
|
assert len(size) == dims
|
||||||
dev = vals.device
|
dev = vals.device
|
||||||
# result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
|
# result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
|
||||||
# # flatten inds
|
# # flatten inds
|
||||||
result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
|
result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
|
||||||
# flatten inds
|
# flatten inds
|
||||||
fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
|
fac = [np.prod(size[i + 1 :]) for i in range(len(size) - 1)] + [1]
|
||||||
fac = torch.tensor(fac, device=dev).type(inds.dtype)
|
fac = torch.tensor(fac, device=dev).type(inds.dtype)
|
||||||
inds_fold = torch.sum(inds*fac, dim=-1) # [#values,]
|
inds_fold = torch.sum(inds * fac, dim=-1) # [#values,]
|
||||||
result.scatter_add_(0, inds_fold, vals)
|
result.scatter_add_(0, inds_fold, vals)
|
||||||
result = result.view(*size)
|
result = result.view(*size)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def point_rasterize(pts, vals, size):
|
def point_rasterize(pts, vals, size):
|
||||||
"""
|
""":param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
|
||||||
:param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
|
|
||||||
:param vals: point values, tensor of shape (batch, num_points, features)
|
:param vals: point values, tensor of shape (batch, num_points, features)
|
||||||
:param size: len(size)=dim tuple for grid size
|
:param size: len(size)=dim tuple for grid size
|
||||||
:return rasterized values (batch, features, res0, res1, res2)
|
:return rasterized values (batch, features, res0, res1, res2)
|
||||||
"""
|
"""
|
||||||
dim = pts.shape[-1]
|
dim = pts.shape[-1]
|
||||||
assert(pts.shape[:2] == vals.shape[:2])
|
assert pts.shape[:2] == vals.shape[:2]
|
||||||
assert(pts.shape[2] == dim)
|
assert pts.shape[2] == dim
|
||||||
size_list = list(size)
|
size_list = list(size)
|
||||||
size = torch.tensor(size).to(pts.device).float()
|
size = torch.tensor(size).to(pts.device).float()
|
||||||
cubesize = 1.0 / size
|
cubesize = 1.0 / size
|
||||||
|
@ -156,20 +159,22 @@ def point_rasterize(pts, vals, size):
|
||||||
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
||||||
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
||||||
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
||||||
tmp = torch.tensor([0,1],dtype=torch.long)
|
tmp = torch.tensor([0, 1], dtype=torch.long)
|
||||||
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
||||||
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
||||||
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
||||||
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||||
# ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
# ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
||||||
ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
ind_b = (
|
||||||
|
torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)
|
||||||
|
) # (batch, num_points, 2**dim)
|
||||||
|
|
||||||
# weights of neighboring nodes
|
# weights of neighboring nodes
|
||||||
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
||||||
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
||||||
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
||||||
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||||
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||||
pos_ = pos_.type(pts.dtype)
|
pos_ = pos_.type(pts.dtype)
|
||||||
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
||||||
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
||||||
|
@ -187,20 +192,21 @@ def point_rasterize(pts, vals, size):
|
||||||
# weighted values
|
# weighted values
|
||||||
vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
|
vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
|
||||||
|
|
||||||
inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
|
inds = inds.view(-1, dim + 2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
|
||||||
vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
|
vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
|
||||||
tensor_size = [bs, nf] + size_list
|
[bs, nf, *size_list]
|
||||||
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list)
|
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf, *size_list])
|
||||||
|
|
||||||
return raster # [batch, nf, res, res, res]
|
return raster # [batch, nf, res, res, res]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##################################################
|
##################################################
|
||||||
# Below are the utilization functions in general
|
# Below are the utilization functions in general
|
||||||
|
|
||||||
class AverageMeter(object):
|
|
||||||
"""Computes and stores the average and current value"""
|
class AverageMeter:
|
||||||
|
"""Computes and stores the average and current value."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
@ -226,47 +232,47 @@ class AverageMeter(object):
|
||||||
def avgcavg(self):
|
def avgcavg(self):
|
||||||
return self.avg.sum().item() / (self.count != 0).sum().item()
|
return self.avg.sum().item() / (self.count != 0).sum().item()
|
||||||
|
|
||||||
|
|
||||||
def load_model_manual(state_dict, model):
|
def load_model_manual(state_dict, model):
|
||||||
new_state_dict = OrderedDict()
|
new_state_dict = OrderedDict()
|
||||||
is_model_parallel = isinstance(model, torch.nn.DataParallel)
|
is_model_parallel = isinstance(model, torch.nn.DataParallel)
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if k.startswith('module.') != is_model_parallel:
|
if k.startswith("module.") != is_model_parallel:
|
||||||
if k.startswith('module.'):
|
if k.startswith("module."):
|
||||||
# remove module
|
# remove module
|
||||||
k = k[7:]
|
k = k[7:]
|
||||||
else:
|
else:
|
||||||
# add module
|
# add module
|
||||||
k = 'module.' + k
|
k = "module." + k
|
||||||
|
|
||||||
new_state_dict[k]=v
|
new_state_dict[k] = v
|
||||||
|
|
||||||
model.load_state_dict(new_state_dict)
|
model.load_state_dict(new_state_dict)
|
||||||
|
|
||||||
|
|
||||||
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
|
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
|
||||||
'''
|
"""Run marching cubes from PSR grid."""
|
||||||
Run marching cubes from PSR grid
|
|
||||||
'''
|
|
||||||
batch_size = psr_grid.shape[0]
|
batch_size = psr_grid.shape[0]
|
||||||
s = psr_grid.shape[-1] # size of psr_grid
|
s = psr_grid.shape[-1] # size of psr_grid
|
||||||
psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
|
psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
|
||||||
|
|
||||||
if batch_size>1:
|
if batch_size > 1:
|
||||||
verts, faces, normals = [], [], []
|
verts, faces, normals = [], [], []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
|
verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
|
||||||
verts.append(verts_cur)
|
verts.append(verts_cur)
|
||||||
faces.append(faces_cur)
|
faces.append(faces_cur)
|
||||||
normals.append(normals_cur)
|
normals.append(normals_cur)
|
||||||
verts = np.stack(verts, axis = 0)
|
verts = np.stack(verts, axis=0)
|
||||||
faces = np.stack(faces, axis = 0)
|
faces = np.stack(faces, axis=0)
|
||||||
normals = np.stack(normals, axis = 0)
|
normals = np.stack(normals, axis=0)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
|
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
|
||||||
except:
|
except:
|
||||||
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
|
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
|
||||||
if real_scale:
|
if real_scale:
|
||||||
verts = verts / (s-1) # scale to range [0, 1]
|
verts = verts / (s - 1) # scale to range [0, 1]
|
||||||
else:
|
else:
|
||||||
verts = verts / s # scale to range [0, 1)
|
verts = verts / s # scale to range [0, 1)
|
||||||
|
|
||||||
|
@ -278,6 +284,7 @@ def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
|
||||||
|
|
||||||
return verts, faces, normals
|
return verts, faces, normals
|
||||||
|
|
||||||
|
|
||||||
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
|
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
|
||||||
verts = verts.squeeze()
|
verts = verts.squeeze()
|
||||||
faces = faces.squeeze()
|
faces = faces.squeeze()
|
||||||
|
@ -293,37 +300,34 @@ def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
|
||||||
v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
|
v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
|
||||||
|
|
||||||
# calculate the intersection point of each pixel and the mesh
|
# calculate the intersection point of each pixel and the mesh
|
||||||
p_inters = w_masked[..., 0, None] * v_a + \
|
p_inters = w_masked[..., 0, None] * v_a + w_masked[..., 1, None] * v_b + w_masked[..., 2, None] * v_c
|
||||||
w_masked[..., 1, None] * v_b + \
|
|
||||||
w_masked[..., 2, None] * v_c
|
|
||||||
else:
|
else:
|
||||||
# backproject ndc to world coordinates using z-buffer
|
# backproject ndc to world coordinates using z-buffer
|
||||||
W, H = img_size[1], img_size[0]
|
W, H = img_size[1], img_size[0]
|
||||||
xy = uv.to(mask.device)[mask]
|
xy = uv.to(mask.device)[mask]
|
||||||
x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
|
x_ndc = 1 - (2 * xy[:, 0]) / (W - 1)
|
||||||
y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
|
y_ndc = 1 - (2 * xy[:, 1]) / (H - 1)
|
||||||
z = zbuf.squeeze().reshape(H * W)[mask]
|
z = zbuf.squeeze().reshape(H * W)[mask]
|
||||||
xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
|
xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
|
||||||
|
|
||||||
p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
|
p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
|
||||||
|
|
||||||
# if there are outlier points, we should remove it
|
# if there are outlier points, we should remove it
|
||||||
if (p_inters.max()>1) | (p_inters.min()<-1):
|
if (p_inters.max() > 1) | (p_inters.min() < -1):
|
||||||
mask_bound = (p_inters>=-1) & (p_inters<=1)
|
mask_bound = (p_inters >= -1) & (p_inters <= 1)
|
||||||
mask_bound = (mask_bound.sum(dim=-1)==3)
|
mask_bound = mask_bound.sum(dim=-1) == 3
|
||||||
mask[mask==True] = mask_bound
|
mask[mask is True] = mask_bound
|
||||||
p_inters = p_inters[mask_bound]
|
p_inters = p_inters[mask_bound]
|
||||||
print('!!!!!find outlier!')
|
print("!!!!!find outlier!")
|
||||||
|
|
||||||
return p_inters, mask, f_p, w_masked
|
return p_inters, mask, f_p, w_masked
|
||||||
|
|
||||||
|
|
||||||
def mesh_rasterization(verts, faces, pose, img_size):
|
def mesh_rasterization(verts, faces, pose, img_size):
|
||||||
'''
|
"""Use PyTorch3D to rasterize the mesh given a camera."""
|
||||||
Use PyTorch3D to rasterize the mesh given a camera
|
|
||||||
'''
|
|
||||||
transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
|
transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
|
||||||
if isinstance(pose, PerspectiveCameras):
|
if isinstance(pose, PerspectiveCameras):
|
||||||
transformed_v[..., 2] = 1/transformed_v[..., 2]
|
transformed_v[..., 2] = 1 / transformed_v[..., 2]
|
||||||
# find p_closest on mesh of each pixel via rasterization
|
# find p_closest on mesh of each pixel via rasterization
|
||||||
transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
|
transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
|
||||||
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
||||||
|
@ -331,7 +335,7 @@ def mesh_rasterization(verts, faces, pose, img_size):
|
||||||
image_size=img_size,
|
image_size=img_size,
|
||||||
blur_radius=0,
|
blur_radius=0,
|
||||||
faces_per_pixel=1,
|
faces_per_pixel=1,
|
||||||
perspective_correct=False
|
perspective_correct=False,
|
||||||
)
|
)
|
||||||
pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
|
pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
|
||||||
mask = pix_to_face.clone() != -1
|
mask = pix_to_face.clone() != -1
|
||||||
|
@ -341,11 +345,11 @@ def mesh_rasterization(verts, faces, pose, img_size):
|
||||||
|
|
||||||
return pix_to_face, w, mask
|
return pix_to_face, w, mask
|
||||||
|
|
||||||
|
|
||||||
def verts_on_largest_mesh(verts, faces):
|
def verts_on_largest_mesh(verts, faces):
|
||||||
'''
|
"""verts: Numpy array or Torch.Tensor (N, 3)
|
||||||
verts: Numpy array or Torch.Tensor (N, 3)
|
faces: Numpy array (N, 3).
|
||||||
faces: Numpy array (N, 3)
|
"""
|
||||||
'''
|
|
||||||
if torch.is_tensor(faces):
|
if torch.is_tensor(faces):
|
||||||
verts = verts.squeeze().detach().cpu().numpy()
|
verts = verts.squeeze().detach().cpu().numpy()
|
||||||
faces = faces.squeeze().int().detach().cpu().numpy()
|
faces = faces.squeeze().int().detach().cpu().numpy()
|
||||||
|
@ -356,7 +360,7 @@ def verts_on_largest_mesh(verts, faces):
|
||||||
v_large, f_large = verts, faces
|
v_large, f_large = verts, faces
|
||||||
else:
|
else:
|
||||||
max_idx = conn_size.argmax() # find the index of the largest component
|
max_idx = conn_size.argmax() # find the index of the largest component
|
||||||
v_large = verts[conn_idx==max_idx] # keep points on the largest component
|
v_large = verts[conn_idx == max_idx] # keep points on the largest component
|
||||||
|
|
||||||
if True:
|
if True:
|
||||||
mesh_largest = trimesh.Trimesh(verts, faces)
|
mesh_largest = trimesh.Trimesh(verts, faces)
|
||||||
|
@ -366,36 +370,41 @@ def verts_on_largest_mesh(verts, faces):
|
||||||
v_large = v_large.astype(np.float32)
|
v_large = v_large.astype(np.float32)
|
||||||
return v_large, f_large
|
return v_large, f_large
|
||||||
|
|
||||||
|
|
||||||
def load_pointcloud(in_file):
|
def load_pointcloud(in_file):
|
||||||
plydata = PlyData.read(in_file)
|
plydata = PlyData.read(in_file)
|
||||||
vertices = np.stack([
|
vertices = np.stack(
|
||||||
plydata['vertex']['x'],
|
[
|
||||||
plydata['vertex']['y'],
|
plydata["vertex"]["x"],
|
||||||
plydata['vertex']['z']
|
plydata["vertex"]["y"],
|
||||||
], axis=1)
|
plydata["vertex"]["z"],
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
return vertices
|
return vertices
|
||||||
|
|
||||||
|
|
||||||
# General config
|
# General config
|
||||||
def load_config(path, default_path=None):
|
def load_config(path, default_path=None):
|
||||||
''' Loads config file.
|
"""Loads config file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): path to config file
|
path (str): path to config file
|
||||||
default_path (bool): whether to use default path
|
default_path (bool): whether to use default path
|
||||||
'''
|
"""
|
||||||
# Load configuration from file itself
|
# Load configuration from file itself
|
||||||
with open(path, 'r') as f:
|
with open(path) as f:
|
||||||
cfg_special = yaml.load(f, Loader=yaml.Loader)
|
cfg_special = yaml.load(f, Loader=yaml.Loader)
|
||||||
|
|
||||||
# Check if we should inherit from a config
|
# Check if we should inherit from a config
|
||||||
inherit_from = cfg_special.get('inherit_from')
|
inherit_from = cfg_special.get("inherit_from")
|
||||||
|
|
||||||
# If yes, load this config first as default
|
# If yes, load this config first as default
|
||||||
# If no, use the default_path
|
# If no, use the default_path
|
||||||
if inherit_from is not None:
|
if inherit_from is not None:
|
||||||
cfg = load_config(inherit_from, default_path)
|
cfg = load_config(inherit_from, default_path)
|
||||||
elif default_path is not None:
|
elif default_path is not None:
|
||||||
with open(default_path, 'r') as f:
|
with open(default_path) as f:
|
||||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||||
else:
|
else:
|
||||||
cfg = dict()
|
cfg = dict()
|
||||||
|
@ -405,65 +414,67 @@ def load_config(path, default_path=None):
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
def update_config(config, unknown):
|
def update_config(config, unknown):
|
||||||
# update config given args
|
# update config given args
|
||||||
for idx,arg in enumerate(unknown):
|
for idx, arg in enumerate(unknown):
|
||||||
if arg.startswith("--"):
|
if arg.startswith("--"):
|
||||||
keys = arg.replace("--","").split(':')
|
keys = arg.replace("--", "").split(":")
|
||||||
assert(len(keys)==2)
|
assert len(keys) == 2
|
||||||
k1, k2 = keys
|
k1, k2 = keys
|
||||||
argtype = type(config[k1][k2])
|
argtype = type(config[k1][k2])
|
||||||
if argtype == bool:
|
if argtype == bool:
|
||||||
v = unknown[idx+1].lower() == 'true'
|
v = unknown[idx + 1].lower() == "true"
|
||||||
else:
|
else:
|
||||||
if config[k1][k2] is not None:
|
if config[k1][k2] is not None:
|
||||||
v = type(config[k1][k2])(unknown[idx+1])
|
v = type(config[k1][k2])(unknown[idx + 1])
|
||||||
else:
|
else:
|
||||||
v = unknown[idx+1]
|
v = unknown[idx + 1]
|
||||||
print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
|
print(f"Changing {k1}:{k2} ---- {config[k1][k2]} to {v}")
|
||||||
config[k1][k2] = v
|
config[k1][k2] = v
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def initialize_logger(cfg):
|
def initialize_logger(cfg):
|
||||||
out_dir = cfg['train']['out_dir']
|
out_dir = cfg["train"]["out_dir"]
|
||||||
if not out_dir:
|
if not out_dir:
|
||||||
os.makedirs(out_dir)
|
os.makedirs(out_dir)
|
||||||
|
|
||||||
cfg['train']['dir_model'] = os.path.join(out_dir, 'model')
|
cfg["train"]["dir_model"] = os.path.join(out_dir, "model")
|
||||||
os.makedirs(cfg['train']['dir_model'], exist_ok=True)
|
os.makedirs(cfg["train"]["dir_model"], exist_ok=True)
|
||||||
|
|
||||||
if cfg['train']['exp_mesh']:
|
|
||||||
cfg['train']['dir_mesh'] = os.path.join(out_dir, 'vis/mesh')
|
|
||||||
os.makedirs(cfg['train']['dir_mesh'], exist_ok=True)
|
|
||||||
if cfg['train']['exp_pcl']:
|
|
||||||
cfg['train']['dir_pcl'] = os.path.join(out_dir, 'vis/pointcloud')
|
|
||||||
os.makedirs(cfg['train']['dir_pcl'], exist_ok=True)
|
|
||||||
if cfg['train']['vis_rendering']:
|
|
||||||
cfg['train']['dir_rendering'] = os.path.join(out_dir, 'vis/rendering')
|
|
||||||
os.makedirs(cfg['train']['dir_rendering'], exist_ok=True)
|
|
||||||
if cfg['train']['o3d_show']:
|
|
||||||
cfg['train']['dir_o3d'] = os.path.join(out_dir, 'vis/o3d')
|
|
||||||
os.makedirs(cfg['train']['dir_o3d'], exist_ok=True)
|
|
||||||
|
|
||||||
|
if cfg["train"]["exp_mesh"]:
|
||||||
|
cfg["train"]["dir_mesh"] = os.path.join(out_dir, "vis/mesh")
|
||||||
|
os.makedirs(cfg["train"]["dir_mesh"], exist_ok=True)
|
||||||
|
if cfg["train"]["exp_pcl"]:
|
||||||
|
cfg["train"]["dir_pcl"] = os.path.join(out_dir, "vis/pointcloud")
|
||||||
|
os.makedirs(cfg["train"]["dir_pcl"], exist_ok=True)
|
||||||
|
if cfg["train"]["vis_rendering"]:
|
||||||
|
cfg["train"]["dir_rendering"] = os.path.join(out_dir, "vis/rendering")
|
||||||
|
os.makedirs(cfg["train"]["dir_rendering"], exist_ok=True)
|
||||||
|
if cfg["train"]["o3d_show"]:
|
||||||
|
cfg["train"]["dir_o3d"] = os.path.join(out_dir, "vis/o3d")
|
||||||
|
os.makedirs(cfg["train"]["dir_o3d"], exist_ok=True)
|
||||||
|
|
||||||
logger = logging.getLogger("train")
|
logger = logging.getLogger("train")
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
logger.handlers = []
|
logger.handlers = []
|
||||||
# ch = logging.StreamHandler()
|
# ch = logging.StreamHandler()
|
||||||
# logger.addHandler(ch)
|
# logger.addHandler(ch)
|
||||||
fh = logging.FileHandler(os.path.join(cfg['train']['out_dir'], "log.txt"))
|
fh = logging.FileHandler(os.path.join(cfg["train"]["out_dir"], "log.txt"))
|
||||||
logger.addHandler(fh)
|
logger.addHandler(fh)
|
||||||
logger.info('Outout dir: %s', out_dir)
|
logger.info("Outout dir: %s", out_dir)
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
def update_recursive(dict1, dict2):
|
def update_recursive(dict1, dict2):
|
||||||
''' Update two config dictionaries recursively.
|
"""Update two config dictionaries recursively.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dict1 (dict): first dictionary to be updated
|
dict1 (dict): first dictionary to be updated
|
||||||
dict2 (dict): second dictionary which entries should be used
|
dict2 (dict): second dictionary which entries should be used
|
||||||
|
|
||||||
'''
|
"""
|
||||||
for k, v in dict2.items():
|
for k, v in dict2.items():
|
||||||
if k not in dict1:
|
if k not in dict1:
|
||||||
dict1[k] = dict()
|
dict1[k] = dict()
|
||||||
|
@ -472,6 +483,7 @@ def update_recursive(dict1, dict2):
|
||||||
else:
|
else:
|
||||||
dict1[k] = v
|
dict1[k] = v
|
||||||
|
|
||||||
|
|
||||||
def export_pointcloud(name, points, normals=None):
|
def export_pointcloud(name, points, normals=None):
|
||||||
if len(points.shape) > 2:
|
if len(points.shape) > 2:
|
||||||
points = points[0]
|
points = points[0]
|
||||||
|
@ -487,6 +499,7 @@ def export_pointcloud(name, points, normals=None):
|
||||||
pcd.normals = o3d.utility.Vector3dVector(normals)
|
pcd.normals = o3d.utility.Vector3dVector(normals)
|
||||||
o3d.io.write_point_cloud(name, pcd)
|
o3d.io.write_point_cloud(name, pcd)
|
||||||
|
|
||||||
|
|
||||||
def export_mesh(name, v, f):
|
def export_mesh(name, v, f):
|
||||||
if len(v.shape) > 2:
|
if len(v.shape) > 2:
|
||||||
v, f = v[0], f[0]
|
v, f = v[0], f[0]
|
||||||
|
@ -498,59 +511,63 @@ def export_mesh(name, v, f):
|
||||||
mesh.triangles = o3d.utility.Vector3iVector(f)
|
mesh.triangles = o3d.utility.Vector3iVector(f)
|
||||||
o3d.io.write_triangle_mesh(name, mesh)
|
o3d.io.write_triangle_mesh(name, mesh)
|
||||||
|
|
||||||
|
|
||||||
def scale2onet(p, scale=1.2):
|
def scale2onet(p, scale=1.2):
|
||||||
'''
|
"""Scale the point cloud from SAP to ONet range."""
|
||||||
Scale the point cloud from SAP to ONet range
|
|
||||||
'''
|
|
||||||
return (p - 0.5) * scale
|
return (p - 0.5) * scale
|
||||||
|
|
||||||
|
|
||||||
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
|
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
|
||||||
if model is not None:
|
if model is not None:
|
||||||
if schedule is not None:
|
if schedule is not None:
|
||||||
optimizer = torch.optim.Adam([
|
optimizer = torch.optim.Adam(
|
||||||
{"params": model.parameters(),
|
[
|
||||||
"lr": schedule[0].get_learning_rate(epoch)},
|
{"params": model.parameters(), "lr": schedule[0].get_learning_rate(epoch)},
|
||||||
{"params": inputs,
|
{"params": inputs, "lr": schedule[1].get_learning_rate(epoch)},
|
||||||
"lr": schedule[1].get_learning_rate(epoch)}])
|
],
|
||||||
elif 'lr' in cfg['train']:
|
)
|
||||||
optimizer = torch.optim.Adam([
|
elif "lr" in cfg["train"]:
|
||||||
{"params": model.parameters(),
|
optimizer = torch.optim.Adam(
|
||||||
"lr": float(cfg['train']['lr'])},
|
[
|
||||||
{"params": inputs,
|
{"params": model.parameters(), "lr": float(cfg["train"]["lr"])},
|
||||||
"lr": float(cfg['train']['lr_pcl'])}])
|
{"params": inputs, "lr": float(cfg["train"]["lr_pcl"])},
|
||||||
|
],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('no known learning rate')
|
msg = "no known learning rate"
|
||||||
|
raise Exception(msg)
|
||||||
else:
|
else:
|
||||||
if schedule is not None:
|
if schedule is not None:
|
||||||
optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
|
optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
|
||||||
else:
|
else:
|
||||||
optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
|
optimizer = torch.optim.Adam([inputs], lr=float(cfg["train"]["lr_pcl"]))
|
||||||
|
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
def is_url(url):
|
def is_url(url):
|
||||||
scheme = urllib.parse.urlparse(url).scheme
|
scheme = urllib.parse.urlparse(url).scheme
|
||||||
return scheme in ('http', 'https')
|
return scheme in ("http", "https")
|
||||||
|
|
||||||
|
|
||||||
def load_url(url):
|
def load_url(url):
|
||||||
'''Load a module dictionary from url.
|
"""Load a module dictionary from url.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url (str): url to saved model
|
url (str): url to saved model
|
||||||
'''
|
"""
|
||||||
print(url)
|
print(url)
|
||||||
print('=> Loading checkpoint from url...')
|
print("=> Loading checkpoint from url...")
|
||||||
state_dict = model_zoo.load_url(url, progress=True)
|
state_dict = model_zoo.load_url(url, progress=True)
|
||||||
|
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class GaussianSmoothing(nn.Module):
|
class GaussianSmoothing(nn.Module):
|
||||||
"""
|
"""Apply gaussian smoothing on a
|
||||||
Apply gaussian smoothing on a
|
|
||||||
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
||||||
in the input using a depthwise convolution.
|
in the input using a depthwise convolution.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
|
channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
|
||||||
kernel_size (int, sequence): Size of the gaussian kernel.
|
kernel_size (int, sequence): Size of the gaussian kernel.
|
||||||
|
@ -558,8 +575,9 @@ class GaussianSmoothing(nn.Module):
|
||||||
dim (int, optional): The number of dimensions of the data.
|
dim (int, optional): The number of dimensions of the data.
|
||||||
Default value is 2 (spatial).
|
Default value is 2 (spatial).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, kernel_size, sigma, dim=3):
|
def __init__(self, channels, kernel_size, sigma, dim=3):
|
||||||
super(GaussianSmoothing, self).__init__()
|
super().__init__()
|
||||||
if isinstance(kernel_size, numbers.Number):
|
if isinstance(kernel_size, numbers.Number):
|
||||||
kernel_size = [kernel_size] * dim
|
kernel_size = [kernel_size] * dim
|
||||||
if isinstance(sigma, numbers.Number):
|
if isinstance(sigma, numbers.Number):
|
||||||
|
@ -569,15 +587,11 @@ class GaussianSmoothing(nn.Module):
|
||||||
# gaussian function of each dimension.
|
# gaussian function of each dimension.
|
||||||
kernel = 1
|
kernel = 1
|
||||||
meshgrids = torch.meshgrid(
|
meshgrids = torch.meshgrid(
|
||||||
[
|
[torch.arange(size, dtype=torch.float32) for size in kernel_size],
|
||||||
torch.arange(size, dtype=torch.float32)
|
|
||||||
for size in kernel_size
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||||
mean = (size - 1) / 2
|
mean = (size - 1) / 2
|
||||||
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
|
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)
|
||||||
torch.exp(-((mgrid - mean) / std) ** 2 / 2)
|
|
||||||
|
|
||||||
# Make sure sum of values in gaussian kernel equals 1.
|
# Make sure sum of values in gaussian kernel equals 1.
|
||||||
kernel = kernel / torch.sum(kernel)
|
kernel = kernel / torch.sum(kernel)
|
||||||
|
@ -586,7 +600,7 @@ class GaussianSmoothing(nn.Module):
|
||||||
kernel = kernel.view(1, 1, *kernel.size())
|
kernel = kernel.view(1, 1, *kernel.size())
|
||||||
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||||
|
|
||||||
self.register_buffer('weight', kernel)
|
self.register_buffer("weight", kernel)
|
||||||
self.groups = channels
|
self.groups = channels
|
||||||
|
|
||||||
if dim == 1:
|
if dim == 1:
|
||||||
|
@ -596,36 +610,44 @@ class GaussianSmoothing(nn.Module):
|
||||||
elif dim == 3:
|
elif dim == 3:
|
||||||
self.conv = F.conv3d
|
self.conv = F.conv3d
|
||||||
else:
|
else:
|
||||||
|
msg = f"Only 1, 2 and 3 dimensions are supported. Received {dim}."
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
"""
|
"""Apply gaussian filter to input.
|
||||||
Apply gaussian filter to input.
|
|
||||||
Arguments:
|
Arguments:
|
||||||
input (torch.Tensor): Input to apply gaussian filter on.
|
input (torch.Tensor): Input to apply gaussian filter on.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
filtered (torch.Tensor): Filtered output.
|
filtered (torch.Tensor): Filtered output.
|
||||||
"""
|
"""
|
||||||
return self.conv(input, weight=self.weight, groups=self.groups)
|
return self.conv(input, weight=self.weight, groups=self.groups)
|
||||||
|
|
||||||
|
|
||||||
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
|
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
|
||||||
def get_learning_rate_schedules(schedule_specs):
|
def get_learning_rate_schedules(schedule_specs):
|
||||||
|
|
||||||
schedules = []
|
schedules = []
|
||||||
|
|
||||||
for key in schedule_specs.keys():
|
for key in schedule_specs.keys():
|
||||||
schedules.append(StepLearningRateSchedule(
|
schedules.append(
|
||||||
schedule_specs[key]['initial'],
|
StepLearningRateSchedule(
|
||||||
|
schedule_specs[key]["initial"],
|
||||||
schedule_specs[key]["interval"],
|
schedule_specs[key]["interval"],
|
||||||
schedule_specs[key]["factor"],
|
schedule_specs[key]["factor"],
|
||||||
schedule_specs[key]["final"]))
|
schedule_specs[key]["final"],
|
||||||
|
),
|
||||||
|
)
|
||||||
return schedules
|
return schedules
|
||||||
|
|
||||||
|
|
||||||
class LearningRateSchedule:
|
class LearningRateSchedule:
|
||||||
def get_learning_rate(self, epoch):
|
def get_learning_rate(self, epoch):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StepLearningRateSchedule(LearningRateSchedule):
|
class StepLearningRateSchedule(LearningRateSchedule):
|
||||||
def __init__(self, initial, interval, factor, final=1e-6):
|
def __init__(self, initial, interval, factor, final=1e-6):
|
||||||
self.initial = float(initial)
|
self.initial = float(initial)
|
||||||
|
@ -640,6 +662,7 @@ class StepLearningRateSchedule(LearningRateSchedule):
|
||||||
else:
|
else:
|
||||||
return self.final
|
return self.final
|
||||||
|
|
||||||
|
|
||||||
def adjust_learning_rate(lr_schedules, optimizer, epoch):
|
def adjust_learning_rate(lr_schedules, optimizer, epoch):
|
||||||
for i, param_group in enumerate(optimizer.param_groups):
|
for i, param_group in enumerate(optimizer.param_groups):
|
||||||
param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
|
param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
|
||||||
|
|
106
src/visualize.py
106
src/visualize.py
|
@ -1,41 +1,41 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import open3d as o3d
|
import open3d as o3d
|
||||||
import matplotlib.pyplot as plt
|
import torch
|
||||||
from skimage import measure
|
|
||||||
from src.utils import calc_inters_points, grid_interp
|
|
||||||
from scipy import ndimage
|
from scipy import ndimage
|
||||||
from tqdm import trange
|
|
||||||
from torchvision.utils import save_image
|
from torchvision.utils import save_image
|
||||||
from pdb import set_trace as st
|
from tqdm import trange
|
||||||
|
|
||||||
|
from src.utils import calc_inters_points, grid_interp
|
||||||
|
|
||||||
|
|
||||||
def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None):
|
def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None):
|
||||||
''' Visualization.
|
"""Visualization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (dict): data dictionary
|
data (dict): data dictionary
|
||||||
depth (int): PSR depth
|
depth (int): PSR depth
|
||||||
out_path (str): output path for the mesh
|
out_path (str): output path for the mesh
|
||||||
'''
|
"""
|
||||||
mesh = o3d.geometry.TriangleMesh()
|
mesh = o3d.geometry.TriangleMesh()
|
||||||
mesh.vertices = o3d.utility.Vector3dVector(verts)
|
mesh.vertices = o3d.utility.Vector3dVector(verts)
|
||||||
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
||||||
mesh.paint_uniform_color(np.array([0.7,0.7,0.7]))
|
mesh.paint_uniform_color(np.array([0.7, 0.7, 0.7]))
|
||||||
if color_v is not None:
|
if color_v is not None:
|
||||||
mesh.vertex_colors = o3d.utility.Vector3dVector(color_v)
|
mesh.vertex_colors = o3d.utility.Vector3dVector(color_v)
|
||||||
|
|
||||||
if vis is not None:
|
if vis is not None:
|
||||||
dir_o3d = cfg['train']['dir_o3d']
|
dir_o3d = cfg["train"]["dir_o3d"]
|
||||||
wire = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
|
o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
|
||||||
|
|
||||||
p = points.squeeze(0).detach().cpu().numpy()
|
p = points.squeeze(0).detach().cpu().numpy()
|
||||||
n = normals.squeeze(0).detach().cpu().numpy()
|
n = normals.squeeze(0).detach().cpu().numpy()
|
||||||
pcd = o3d.geometry.PointCloud()
|
pcd = o3d.geometry.PointCloud()
|
||||||
pcd.points = o3d.utility.Vector3dVector(p)
|
pcd.points = o3d.utility.Vector3dVector(p)
|
||||||
pcd.normals = o3d.utility.Vector3dVector(n)
|
pcd.normals = o3d.utility.Vector3dVector(n)
|
||||||
pcd.paint_uniform_color(np.array([0.7,0.7,1.0]))
|
pcd.paint_uniform_color(np.array([0.7, 0.7, 1.0]))
|
||||||
# pcd = pcd.uniform_down_sample(5)
|
# pcd = pcd.uniform_down_sample(5)
|
||||||
|
|
||||||
vis.clear_geometries()
|
vis.clear_geometries()
|
||||||
|
@ -43,53 +43,51 @@ def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, co
|
||||||
vis.update_geometry(mesh)
|
vis.update_geometry(mesh)
|
||||||
|
|
||||||
#! Thingi wheel - an example for how to change cameras in Open3D viewers
|
#! Thingi wheel - an example for how to change cameras in Open3D viewers
|
||||||
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
|
vis.get_view_control().set_front([0.0461, -0.7467, 0.6635])
|
||||||
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
|
vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638])
|
||||||
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
|
vis.get_view_control().set_up([0.0520, 0.6651, 0.7449])
|
||||||
vis.get_view_control().set_zoom(0.7)
|
vis.get_view_control().set_zoom(0.7)
|
||||||
vis.poll_events()
|
vis.poll_events()
|
||||||
|
|
||||||
out_path = os.path.join(dir_o3d, '{}.jpg'.format(it))
|
out_path = os.path.join(dir_o3d, f"{it}.jpg")
|
||||||
vis.capture_screen_image(out_path)
|
vis.capture_screen_image(out_path)
|
||||||
|
|
||||||
vis.clear_geometries()
|
vis.clear_geometries()
|
||||||
vis.add_geometry(pcd, reset_bounding_box=False)
|
vis.add_geometry(pcd, reset_bounding_box=False)
|
||||||
vis.update_geometry(pcd)
|
vis.update_geometry(pcd)
|
||||||
vis.get_render_option().point_show_normal=True # visualize point normals
|
vis.get_render_option().point_show_normal = True # visualize point normals
|
||||||
|
|
||||||
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
|
vis.get_view_control().set_front([0.0461, -0.7467, 0.6635])
|
||||||
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
|
vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638])
|
||||||
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
|
vis.get_view_control().set_up([0.0520, 0.6651, 0.7449])
|
||||||
vis.get_view_control().set_zoom(0.7)
|
vis.get_view_control().set_zoom(0.7)
|
||||||
vis.poll_events()
|
vis.poll_events()
|
||||||
|
|
||||||
out_path = os.path.join(dir_o3d, '{}_pcd.jpg'.format(it))
|
out_path = os.path.join(dir_o3d, f"{it}_pcd.jpg")
|
||||||
vis.capture_screen_image(out_path)
|
vis.capture_screen_image(out_path)
|
||||||
|
|
||||||
def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'):
|
|
||||||
|
def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name="video.mp4"):
|
||||||
if pose is not None:
|
if pose is not None:
|
||||||
device = psr_grid.device
|
device = psr_grid.device
|
||||||
# get world coordinate of grid points [-1, 1]
|
# get world coordinate of grid points [-1, 1]
|
||||||
res = psr_grid.shape[-1]
|
res = psr_grid.shape[-1]
|
||||||
x = torch.linspace(-1, 1, steps=res)
|
x = torch.linspace(-1, 1, steps=res)
|
||||||
co_x, co_y, co_z = torch.meshgrid(x, x, x)
|
co_x, co_y, co_z = torch.meshgrid(x, x, x)
|
||||||
co_grid = torch.stack(
|
co_grid = torch.stack([co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)], dim=1).to(device).unsqueeze(0)
|
||||||
[co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)],
|
|
||||||
dim=1).to(device).unsqueeze(0)
|
|
||||||
|
|
||||||
# visualize the projected occ_soft value
|
# visualize the projected occ_soft value
|
||||||
res = 128
|
res = 128
|
||||||
psr_grid = psr_grid.reshape(-1)
|
psr_grid = psr_grid.reshape(-1)
|
||||||
out_mask = psr_grid>0
|
out_mask = psr_grid > 0
|
||||||
in_mask = psr_grid<0
|
in_mask = psr_grid < 0
|
||||||
pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze()
|
pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze()
|
||||||
vis_mask = (pix[..., 0]>=0) & (pix[..., 0]<=res-1) & \
|
vis_mask = (pix[..., 0] >= 0) & (pix[..., 0] <= res - 1) & (pix[..., 1] >= 0) & (pix[..., 1] <= res - 1)
|
||||||
(pix[..., 1]>=0) & (pix[..., 1]<=res-1)
|
|
||||||
pix_out = pix[vis_mask & out_mask]
|
pix_out = pix[vis_mask & out_mask]
|
||||||
pix_in = pix[vis_mask & in_mask]
|
pix_in = pix[vis_mask & in_mask]
|
||||||
|
|
||||||
img = torch.ones([res,res]).to(device)
|
img = torch.ones([res, res]).to(device)
|
||||||
psr_grid = torch.sigmoid(- psr_grid * 5)
|
psr_grid = torch.sigmoid(-psr_grid * 5)
|
||||||
img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask]
|
img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask]
|
||||||
img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask]
|
img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask]
|
||||||
# save_image(img, 'tmp.png', normalize=True)
|
# save_image(img, 'tmp.png', normalize=True)
|
||||||
|
@ -98,70 +96,70 @@ def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.
|
||||||
dir_psr_vis = out_dir
|
dir_psr_vis = out_dir
|
||||||
os.makedirs(dir_psr_vis, exist_ok=True)
|
os.makedirs(dir_psr_vis, exist_ok=True)
|
||||||
psr_grid = psr_grid.squeeze().detach().cpu().numpy()
|
psr_grid = psr_grid.squeeze().detach().cpu().numpy()
|
||||||
axis = ['x', 'y', 'z']
|
|
||||||
s = psr_grid.shape[0]
|
s = psr_grid.shape[0]
|
||||||
for idx in trange(s):
|
for idx in trange(s):
|
||||||
my_dpi = 100
|
my_dpi = 100
|
||||||
plt.figure(figsize=(1000/my_dpi, 300/my_dpi), dpi=my_dpi)
|
plt.figure(figsize=(1000 / my_dpi, 300 / my_dpi), dpi=my_dpi)
|
||||||
plt.subplot(1, 3, 1)
|
plt.subplot(1, 3, 1)
|
||||||
plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode='nearest'), cmap='nipy_spectral')
|
plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode="nearest"), cmap="nipy_spectral")
|
||||||
plt.clim(-1, 1)
|
plt.clim(-1, 1)
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title('x')
|
plt.title("x")
|
||||||
plt.grid("off")
|
plt.grid("off")
|
||||||
plt.axis("off")
|
plt.axis("off")
|
||||||
|
|
||||||
plt.subplot(1, 3, 2)
|
plt.subplot(1, 3, 2)
|
||||||
plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode='nearest'), cmap='nipy_spectral')
|
plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode="nearest"), cmap="nipy_spectral")
|
||||||
plt.clim(-1, 1)
|
plt.clim(-1, 1)
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title('y')
|
plt.title("y")
|
||||||
plt.grid("off")
|
plt.grid("off")
|
||||||
plt.axis("off")
|
plt.axis("off")
|
||||||
|
|
||||||
plt.subplot(1, 3, 3)
|
plt.subplot(1, 3, 3)
|
||||||
plt.imshow(ndimage.rotate(psr_grid[:,:,idx], 90, mode='nearest'), cmap='nipy_spectral')
|
plt.imshow(ndimage.rotate(psr_grid[:, :, idx], 90, mode="nearest"), cmap="nipy_spectral")
|
||||||
plt.clim(-1, 1)
|
plt.clim(-1, 1)
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title('z')
|
plt.title("z")
|
||||||
plt.grid("off")
|
plt.grid("off")
|
||||||
plt.axis("off")
|
plt.axis("off")
|
||||||
|
|
||||||
|
plt.savefig(os.path.join(dir_psr_vis, f"{idx}"), pad_inches=0, dpi=100)
|
||||||
plt.savefig(os.path.join(dir_psr_vis, '{}'.format(idx)), pad_inches = 0, dpi=100)
|
|
||||||
plt.close()
|
plt.close()
|
||||||
os.system("rm {}/{}".format(dir_psr_vis, out_video_name))
|
os.system(f"rm {dir_psr_vis}/{out_video_name}")
|
||||||
os.system("ffmpeg -framerate 25 -start_number 0 -i {}/%d.png -pix_fmt yuv420p -crf 17 {}/{}".format(dir_psr_vis, dir_psr_vis, out_video_name))
|
os.system(
|
||||||
|
f"ffmpeg -framerate 25 -start_number 0 -i {dir_psr_vis}/%d.png -pix_fmt yuv420p -crf 17 {dir_psr_vis}/{out_video_name}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
def visualize_mesh_phong(v, f, n, pose, img_size, name, device='cpu'):
|
|
||||||
|
def visualize_mesh_phong(v, f, n, pose, img_size, name, device="cpu"):
|
||||||
#! Mesh rendering using Phong shading model
|
#! Mesh rendering using Phong shading model
|
||||||
_, mask, f_p, w = calc_inters_points(v, f, pose, img_size)
|
_, mask, f_p, w = calc_inters_points(v, f, pose, img_size)
|
||||||
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
|
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
|
||||||
n_inters = w[..., 0, None] * n_a.squeeze() + \
|
n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze()
|
||||||
w[..., 1, None] * n_b.squeeze() + \
|
|
||||||
w[..., 2, None] * n_c.squeeze()
|
|
||||||
n_inters = n_inters.detach().to(device)
|
n_inters = n_inters.detach().to(device)
|
||||||
light_source = -pose.R@pose.T.squeeze()
|
light_source = -pose.R @ pose.T.squeeze()
|
||||||
light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float()
|
light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float()
|
||||||
diffuse_per = torch.Tensor([0.7,0.7,0.7]).float()
|
diffuse_per = torch.Tensor([0.7, 0.7, 0.7]).float()
|
||||||
ambiant = torch.Tensor([0.3,0.3,0.3]).float()
|
ambiant = torch.Tensor([0.3, 0.3, 0.3]).float()
|
||||||
|
|
||||||
diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device)
|
diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device)
|
||||||
|
|
||||||
phong = torch.ones([img_size[0]*img_size[1], 3]).to(device)
|
phong = torch.ones([img_size[0] * img_size[1], 3]).to(device)
|
||||||
phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0)
|
phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0)
|
||||||
pp = phong.reshape(img_size[0], img_size[1], -1)
|
pp = phong.reshape(img_size[0], img_size[1], -1)
|
||||||
save_image(pp.permute(2, 0, 1), name)
|
save_image(pp.permute(2, 0, 1), name)
|
||||||
|
|
||||||
|
|
||||||
def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None):
|
def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None):
|
||||||
p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
|
p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
|
||||||
# normals for p_inters
|
# normals for p_inters
|
||||||
n_inters = None
|
n_inters = None
|
||||||
if n is not None:
|
if n is not None:
|
||||||
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
|
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
|
||||||
n_inters = w[..., 0, None] * n_a.squeeze() + \
|
n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze()
|
||||||
w[..., 1, None] * n_b.squeeze() + \
|
|
||||||
w[..., 2, None] * n_c.squeeze()
|
|
||||||
if ray is not None:
|
if ray is not None:
|
||||||
ray = ray.squeeze()[mask]
|
ray = ray.squeeze()[mask]
|
||||||
|
|
||||||
|
|
159
train.py
159
train.py
|
@ -4,79 +4,90 @@ abspath = os.path.abspath(__file__)
|
||||||
dname = os.path.dirname(abspath)
|
dname = os.path.dirname(abspath)
|
||||||
os.chdir(dname)
|
os.chdir(dname)
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import open3d as o3d
|
|
||||||
|
|
||||||
import numpy as np; np.set_printoptions(precision=4)
|
|
||||||
import shutil, argparse, time
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from src import config
|
from src import config
|
||||||
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
|
from src.data import collate_remove_none, worker_init_fn
|
||||||
from src.training import Trainer
|
|
||||||
from src.model import Encode2Points
|
from src.model import Encode2Points
|
||||||
from src.utils import load_config, initialize_logger, \
|
from src.training import Trainer
|
||||||
AverageMeter, load_model_manual
|
from src.utils import AverageMeter, initialize_logger, load_config, load_model_manual
|
||||||
|
|
||||||
|
np.set_printoptions(precision=4)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||||
parser.add_argument('config', type=str, help='Path to config file.')
|
parser.add_argument("config", type=str, help="Path to config file.")
|
||||||
parser.add_argument('--no_cuda', action='store_true', default=False,
|
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
|
||||||
help='disables CUDA training')
|
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
|
||||||
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
cfg = load_config(args.config, 'configs/default.yaml')
|
cfg = load_config(args.config, "configs/default.yaml")
|
||||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
device = torch.device("cuda" if use_cuda else "cpu")
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
input_type = cfg['data']['input_type']
|
cfg["data"]["input_type"]
|
||||||
batch_size = cfg['train']['batch_size']
|
batch_size = cfg["train"]["batch_size"]
|
||||||
model_selection_metric = cfg['train']['model_selection_metric']
|
model_selection_metric = cfg["train"]["model_selection_metric"]
|
||||||
|
|
||||||
# PYTORCH VERSION > 1.0.0
|
# PYTORCH VERSION > 1.0.0
|
||||||
assert(float(torch.__version__.split('.')[-3]) > 0)
|
assert float(torch.__version__.split(".")[-3]) > 0
|
||||||
|
|
||||||
# boiler-plate
|
# boiler-plate
|
||||||
if cfg['train']['timestamp']:
|
if cfg["train"]["timestamp"]:
|
||||||
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
|
cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||||
logger = initialize_logger(cfg)
|
logger = initialize_logger(cfg)
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml'))
|
shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))
|
||||||
|
|
||||||
logger.info("using GPU: " + torch.cuda.get_device_name(0))
|
logger.info("using GPU: " + torch.cuda.get_device_name(0))
|
||||||
|
|
||||||
# TensorboardX writer
|
# TensorboardX writer
|
||||||
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
|
tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
|
||||||
if not os.path.exists(tblogdir):
|
if not os.path.exists(tblogdir):
|
||||||
os.makedirs(tblogdir, exist_ok=True)
|
os.makedirs(tblogdir, exist_ok=True)
|
||||||
writer = SummaryWriter(log_dir=tblogdir)
|
writer = SummaryWriter(log_dir=tblogdir)
|
||||||
|
|
||||||
|
|
||||||
inputs = None
|
inputs = None
|
||||||
train_dataset = config.get_dataset('train', cfg)
|
train_dataset = config.get_dataset("train", cfg)
|
||||||
val_dataset = config.get_dataset('val', cfg)
|
val_dataset = config.get_dataset("val", cfg)
|
||||||
vis_dataset = config.get_dataset('vis', cfg)
|
vis_dataset = config.get_dataset("vis", cfg)
|
||||||
|
|
||||||
|
|
||||||
collate_fn = collate_remove_none
|
collate_fn = collate_remove_none
|
||||||
|
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'], shuffle=True,
|
train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=cfg["train"]["n_workers"],
|
||||||
|
shuffle=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
worker_init_fn=worker_init_fn)
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
val_loader = torch.utils.data.DataLoader(
|
val_loader = torch.utils.data.DataLoader(
|
||||||
val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
|
val_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=cfg["train"]["n_workers_val"],
|
||||||
|
shuffle=False,
|
||||||
collate_fn=collate_remove_none,
|
collate_fn=collate_remove_none,
|
||||||
worker_init_fn=worker_init_fn)
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
vis_loader = torch.utils.data.DataLoader(
|
vis_loader = torch.utils.data.DataLoader(
|
||||||
vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
|
vis_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=cfg["train"]["n_workers_val"],
|
||||||
|
shuffle=False,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
worker_init_fn=worker_init_fn)
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
|
model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
|
||||||
|
@ -84,38 +95,35 @@ def main():
|
||||||
model = Encode2Points(cfg).to(device)
|
model = Encode2Points(cfg).to(device)
|
||||||
|
|
||||||
n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
logger.info('Number of parameters: %d'% n_parameter)
|
logger.info("Number of parameters: %d" % n_parameter)
|
||||||
# load model
|
# load model
|
||||||
try:
|
try:
|
||||||
# load model
|
# load model
|
||||||
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
|
||||||
load_model_manual(state_dict['state_dict'], model)
|
load_model_manual(state_dict["state_dict"], model)
|
||||||
|
|
||||||
out = "Load model from iteration %d" % state_dict.get('it', 0)
|
out = "Load model from iteration %d" % state_dict.get("it", 0)
|
||||||
logger.info(out)
|
logger.info(out)
|
||||||
# load point cloud
|
# load point cloud
|
||||||
except:
|
except:
|
||||||
state_dict = dict()
|
state_dict = dict()
|
||||||
|
|
||||||
metric_val_best = state_dict.get(
|
metric_val_best = state_dict.get("loss_val_best", np.inf)
|
||||||
'loss_val_best', np.inf)
|
|
||||||
|
|
||||||
logger.info('Current best validation metric (%s): %.8f'
|
logger.info(f"Current best validation metric ({model_selection_metric}): {metric_val_best:.8f}")
|
||||||
% (model_selection_metric, metric_val_best))
|
|
||||||
|
|
||||||
LR = float(cfg['train']['lr'])
|
LR = float(cfg["train"]["lr"])
|
||||||
optimizer = optim.Adam(model.parameters(), lr=LR)
|
optimizer = optim.Adam(model.parameters(), lr=LR)
|
||||||
|
|
||||||
start_epoch = state_dict.get('epoch', -1)
|
start_epoch = state_dict.get("epoch", -1)
|
||||||
it = state_dict.get('it', -1)
|
it = state_dict.get("it", -1)
|
||||||
|
|
||||||
trainer = Trainer(cfg, optimizer, device=device)
|
trainer = Trainer(cfg, optimizer, device=device)
|
||||||
runtime = {}
|
runtime = {}
|
||||||
runtime['all'] = AverageMeter()
|
runtime["all"] = AverageMeter()
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
|
for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
|
||||||
|
|
||||||
for batch in train_loader:
|
for batch in train_loader:
|
||||||
it += 1
|
it += 1
|
||||||
|
|
||||||
|
@ -124,62 +132,57 @@ def main():
|
||||||
|
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
end = time.time()
|
end = time.time()
|
||||||
runtime['all'].update(end - start)
|
runtime["all"].update(end - start)
|
||||||
|
|
||||||
if it % cfg['train']['print_every'] == 0:
|
if it % cfg["train"]["print_every"] == 0:
|
||||||
log_text = ('[Epoch %02d] it=%d, loss=%.4f') %(epoch, it, loss)
|
log_text = ("[Epoch %02d] it=%d, loss=%.4f") % (epoch, it, loss)
|
||||||
writer.add_scalar('train/loss', loss, it)
|
writer.add_scalar("train/loss", loss, it)
|
||||||
if loss_each is not None:
|
if loss_each is not None:
|
||||||
for k, l in loss_each.items():
|
for k, l in loss_each.items():
|
||||||
if l.item() != 0.:
|
if l.item() != 0.0:
|
||||||
log_text += (' loss_%s=%.4f') % (k, l.item())
|
log_text += f" loss_{k}={l.item():.4f}"
|
||||||
writer.add_scalar('train/%s' % k, l, it)
|
writer.add_scalar("train/%s" % k, l, it)
|
||||||
|
|
||||||
log_text += (' time=%.3f / %.2f') % (runtime['all'].val, runtime['all'].sum)
|
log_text += (" time={:.3f} / {:.2f}").format(runtime["all"].val, runtime["all"].sum)
|
||||||
logger.info(log_text)
|
logger.info(log_text)
|
||||||
|
|
||||||
if (it>0)& (it % cfg['train']['visualize_every'] == 0):
|
if (it > 0) & (it % cfg["train"]["visualize_every"] == 0):
|
||||||
for i, batch_vis in enumerate(vis_loader):
|
for i, batch_vis in enumerate(vis_loader):
|
||||||
trainer.save(model, batch_vis, it, i)
|
trainer.save(model, batch_vis, it, i)
|
||||||
if i >= 4:
|
if i >= 4:
|
||||||
break
|
break
|
||||||
logger.info('Saved mesh and pointcloud')
|
logger.info("Saved mesh and pointcloud")
|
||||||
|
|
||||||
# run validation
|
# run validation
|
||||||
if it > 0 and (it % cfg['train']['validate_every']) == 0:
|
if it > 0 and (it % cfg["train"]["validate_every"]) == 0:
|
||||||
eval_dict = trainer.evaluate(val_loader, model)
|
eval_dict = trainer.evaluate(val_loader, model)
|
||||||
metric_val = eval_dict[model_selection_metric]
|
metric_val = eval_dict[model_selection_metric]
|
||||||
logger.info('Validation metric (%s): %.4f'
|
logger.info(f"Validation metric ({model_selection_metric}): {metric_val:.4f}")
|
||||||
% (model_selection_metric, metric_val))
|
|
||||||
|
|
||||||
for k, v in eval_dict.items():
|
for k, v in eval_dict.items():
|
||||||
writer.add_scalar('val/%s' % k, v, it)
|
writer.add_scalar("val/%s" % k, v, it)
|
||||||
|
|
||||||
if -(metric_val - metric_val_best) >= 0:
|
if -(metric_val - metric_val_best) >= 0:
|
||||||
metric_val_best = metric_val
|
metric_val_best = metric_val
|
||||||
logger.info('New best model (loss %.4f)' % metric_val_best)
|
logger.info("New best model (loss %.4f)" % metric_val_best)
|
||||||
state = {'epoch': epoch,
|
state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
|
||||||
'it': it,
|
state["state_dict"] = model.state_dict()
|
||||||
'loss_val_best': metric_val_best}
|
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model_best.pt"))
|
||||||
state['state_dict'] = model.state_dict()
|
|
||||||
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt'))
|
|
||||||
|
|
||||||
# save checkpoint
|
# save checkpoint
|
||||||
if (epoch > 0) & (it % cfg['train']['checkpoint_every'] == 0):
|
if (epoch > 0) & (it % cfg["train"]["checkpoint_every"] == 0):
|
||||||
state = {'epoch': epoch,
|
state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
|
||||||
'it': it,
|
state["state_dict"] = model.state_dict()
|
||||||
'loss_val_best': metric_val_best}
|
|
||||||
pcl = None
|
|
||||||
state['state_dict'] = model.state_dict()
|
|
||||||
|
|
||||||
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))
|
||||||
|
|
||||||
if (it % cfg['train']['backup_every'] == 0):
|
if it % cfg["train"]["backup_every"] == 0:
|
||||||
torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt'))
|
torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % it + ".pt"))
|
||||||
logger.info("Backup model at iteration %d" % it)
|
logger.info("Backup model at iteration %d" % it)
|
||||||
logger.info("Save new model at iteration %d" % it)
|
logger.info("Save new model at iteration %d" % it)
|
||||||
|
|
||||||
done=time.time()
|
time.time()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
Loading…
Reference in a new issue