Compare commits
10 commits
79c35e453a
...
43164f892f
Author | SHA1 | Date | |
---|---|---|---|
43164f892f | |||
377132fdd5 | |||
241631628b | |||
aa60b82868 | |||
41b213e4a5 | |||
ee1be08bce | |||
ee2501d2bd | |||
388991be13 | |||
671b0c7c81 | |||
3967559875 |
15
.editorconfig
Normal file
15
.editorconfig
Normal file
|
@ -0,0 +1,15 @@
|
|||
# EditorConfig is awesome: https://EditorConfig.org
|
||||
|
||||
# top-most EditorConfig file
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.{json,toml,yaml,yml,md}]
|
||||
indent_size = 2
|
28
.gitattributes
vendored
Normal file
28
.gitattributes
vendored
Normal file
|
@ -0,0 +1,28 @@
|
|||
# https://github.com/alexkaratarakis/gitattributes/blob/master/Python.gitattributes
|
||||
# Basic .gitattributes for a python repo.
|
||||
|
||||
# Source files
|
||||
*.pxd text diff=python
|
||||
*.py text diff=python
|
||||
*.py3 text diff=python
|
||||
*.pyw text diff=python
|
||||
*.pyx text diff=python
|
||||
*.pyz text diff=python
|
||||
*.pyi text diff=python
|
||||
|
||||
# Binary files
|
||||
*.db binary
|
||||
*.p binary
|
||||
*.pkl binary
|
||||
*.pickle binary
|
||||
*.pyc binary export-ignore
|
||||
*.pyo binary export-ignore
|
||||
*.pyd binary
|
||||
|
||||
# Jupyter notebook
|
||||
*.ipynb text
|
||||
|
||||
# Note: .db, .p, and .pkl files are associated
|
||||
# with the python modules ``pickle``, ``dbm.*``,
|
||||
# ``shelve``, ``marshal``, ``anydbm``, & ``bsddb``
|
||||
# (among others).
|
71
.gitignore
vendored
71
.gitignore
vendored
|
@ -1,32 +1,22 @@
|
|||
/out
|
||||
/data
|
||||
.vscode
|
||||
.cache
|
||||
*.pyc
|
||||
*.pyd
|
||||
*.pt
|
||||
*.so
|
||||
*.o
|
||||
*.prof
|
||||
*.swp
|
||||
*.lib
|
||||
*.obj
|
||||
*.exp
|
||||
.nfs*
|
||||
*.jpg
|
||||
*.png
|
||||
*.ply
|
||||
*.off
|
||||
*.npz
|
||||
*.txt
|
||||
# *.sh
|
||||
# Personal ignores
|
||||
lightning_logs/
|
||||
|
||||
*.tar.gz
|
||||
*.vtk
|
||||
|
||||
demo/
|
||||
|
||||
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||
# Basic .gitignore for a python repo.
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
|
@ -41,7 +31,6 @@ parts/
|
|||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
|
@ -71,6 +60,7 @@ coverage.xml
|
|||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
|
@ -93,6 +83,7 @@ instance/
|
|||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
|
@ -103,7 +94,9 @@ profile_default/
|
|||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
|
@ -112,7 +105,22 @@ ipython_config.py
|
|||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
|
@ -148,3 +156,16 @@ dmypy.json
|
|||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
|
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
|
@ -1,5 +1,5 @@
|
|||
{
|
||||
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pyg/bin/python", // required for python ide tools
|
||||
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/sap/bin/python", // required for python ide tools
|
||||
"python.terminal.activateEnvironment": false, // or else terminal gets bugged
|
||||
// good practice settings
|
||||
"editor.formatOnSave": true,
|
||||
|
@ -33,7 +33,7 @@
|
|||
"path": "bash",
|
||||
"icon": "rocket",
|
||||
"env": {
|
||||
"CONDAENV": "pyg",
|
||||
"CONDAENV": "sap",
|
||||
},
|
||||
"args": [
|
||||
"-c",
|
||||
|
@ -42,7 +42,6 @@
|
|||
}
|
||||
},
|
||||
"terminal.integrated.env.linux": {
|
||||
"PYTHONPATH": "${workspaceFolder}/src/",
|
||||
"SLURM_JOB_ID": null,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,12 +33,13 @@ conda activate sap
|
|||
|
||||
Next, install [PyTorch3D](https://pytorch3d.org/) (**>=0.5**) yourself by following the [official instructions](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#3-install-wheels-for-linux).
|
||||
|
||||
And install [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter):
|
||||
```sh
|
||||
conda install pytorch-scatter -c pyg
|
||||
```bash
|
||||
git clone https://github.com/facebookresearch/pytorch3d.git
|
||||
cd pytorch3d
|
||||
module load compilers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
## Demo - Quick Start
|
||||
|
||||
First, run the script to get the demo data:
|
||||
|
|
|
@ -1,28 +1,44 @@
|
|||
name: sap
|
||||
|
||||
channels:
|
||||
- nodefaults
|
||||
- conda-forge
|
||||
- pytorch
|
||||
- defaults
|
||||
- nvidia
|
||||
- pyg
|
||||
|
||||
dependencies:
|
||||
- python
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit
|
||||
- numpy
|
||||
- matplotlib
|
||||
- pyyaml
|
||||
- scipy
|
||||
#---# basic python
|
||||
- python=3.8
|
||||
- tqdm
|
||||
- pyyaml
|
||||
#---# visu
|
||||
- matplotlib
|
||||
#---# scientific
|
||||
- numpy
|
||||
- scipy
|
||||
- trimesh
|
||||
- igl
|
||||
#---# pytorch
|
||||
- pytorch
|
||||
- pytorch-cuda=11.8
|
||||
- cudatoolkit
|
||||
- torchvision
|
||||
- torch-scatter
|
||||
#---# tooling (linting, typing...)
|
||||
- ruff
|
||||
- mypy
|
||||
- black
|
||||
- isort
|
||||
#---# logging
|
||||
- tensorboard
|
||||
#---# pip-only packages (not installed from conda channels)
|
||||
- pip
|
||||
- pip:
|
||||
- plyfile==0.7
|
||||
# - open3d>=0.11.1
|
||||
- scikit-image>=0.18.0
|
||||
- python-mnist==0.7
|
||||
- opencv-python>=4.4
|
||||
- av==8.0.3
|
||||
- pykdtree==1.3.4
|
||||
- ipdb==0.13.7
|
||||
- plyfile
|
||||
- scikit-image
|
||||
- python-mnist
|
||||
- opencv-python
|
||||
- av
|
||||
- pykdtree
|
||||
- ipdb
|
||||
|
|
148
eval_meshes.py
148
eval_meshes.py
|
@ -1,155 +1,145 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import trimesh
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import numpy as np; np.set_printoptions(precision=4)
|
||||
import shutil, argparse, time, os
|
||||
import pandas as pd
|
||||
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
|
||||
from src.training import Trainer
|
||||
from src.model import Encode2Points
|
||||
from src.data import PointCloudField, IndexField, Shapes3dDataset
|
||||
from src.utils import load_config, load_pointcloud
|
||||
from src.eval import MeshEvaluator
|
||||
from tqdm import tqdm
|
||||
from pdb import set_trace as st
|
||||
|
||||
from src.data import IndexField, PointCloudField, Shapes3dDataset
|
||||
from src.eval import MeshEvaluator
|
||||
from src.utils import load_config, load_pointcloud
|
||||
|
||||
np.set_printoptions(precision=4)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||
parser.add_argument('config', type=str, help='Path to config file.')
|
||||
parser.add_argument('--no_cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
||||
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
|
||||
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||
parser.add_argument("config", type=str, help="Path to config file.")
|
||||
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
|
||||
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
|
||||
parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")
|
||||
|
||||
args = parser.parse_args()
|
||||
cfg = load_config(args.config, 'configs/default.yaml')
|
||||
cfg = load_config(args.config, "configs/default.yaml")
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
data_type = cfg['data']['data_type']
|
||||
torch.device("cuda" if use_cuda else "cpu")
|
||||
cfg["data"]["data_type"]
|
||||
# Shorthands
|
||||
out_dir = cfg['train']['out_dir']
|
||||
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
|
||||
out_dir = cfg["train"]["out_dir"]
|
||||
generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])
|
||||
|
||||
if cfg['generation'].get('iter', 0)!=0:
|
||||
generation_dir += '_%04d'%cfg['generation']['iter']
|
||||
if cfg["generation"].get("iter", 0) != 0:
|
||||
generation_dir += "_%04d" % cfg["generation"]["iter"]
|
||||
elif args.iter is not None:
|
||||
generation_dir += '_%04d'%args.iter
|
||||
generation_dir += "_%04d" % args.iter
|
||||
|
||||
print('Evaluate meshes under %s'%generation_dir)
|
||||
print("Evaluate meshes under %s" % generation_dir)
|
||||
|
||||
out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl')
|
||||
out_file_class = os.path.join(generation_dir, 'eval_meshes.csv')
|
||||
out_file = os.path.join(generation_dir, "eval_meshes_full.pkl")
|
||||
out_file_class = os.path.join(generation_dir, "eval_meshes.csv")
|
||||
|
||||
# PYTORCH VERSION > 1.0.0
|
||||
assert(float(torch.__version__.split('.')[-3]) > 0)
|
||||
assert float(torch.__version__.split(".")[-3]) > 0
|
||||
|
||||
pointcloud_field = PointCloudField(cfg['data']['pointcloud_file'])
|
||||
pointcloud_field = PointCloudField(cfg["data"]["pointcloud_file"])
|
||||
fields = {
|
||||
'pointcloud': pointcloud_field,
|
||||
'idx': IndexField(),
|
||||
"pointcloud": pointcloud_field,
|
||||
"idx": IndexField(),
|
||||
}
|
||||
|
||||
print('Test split: ', cfg['data']['test_split'])
|
||||
print("Test split: ", cfg["data"]["test_split"])
|
||||
|
||||
dataset_folder = cfg['data']['path']
|
||||
dataset_folder = cfg["data"]["path"]
|
||||
dataset = Shapes3dDataset(
|
||||
dataset_folder, fields,
|
||||
cfg['data']['test_split'],
|
||||
categories=cfg['data']['class'], cfg=cfg)
|
||||
dataset_folder, fields, cfg["data"]["test_split"], categories=cfg["data"]["class"], cfg=cfg,
|
||||
)
|
||||
|
||||
# Loader
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=1, num_workers=0, shuffle=False)
|
||||
test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)
|
||||
|
||||
# Evaluator
|
||||
evaluator = MeshEvaluator(n_points=100000)
|
||||
|
||||
eval_dicts = []
|
||||
print('Evaluating meshes...')
|
||||
for it, data in enumerate(tqdm(test_loader)):
|
||||
|
||||
print("Evaluating meshes...")
|
||||
for _it, data in enumerate(tqdm(test_loader)):
|
||||
if data is None:
|
||||
print('Invalid data.')
|
||||
print("Invalid data.")
|
||||
continue
|
||||
|
||||
mesh_dir = os.path.join(generation_dir, 'meshes')
|
||||
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
|
||||
|
||||
mesh_dir = os.path.join(generation_dir, "meshes")
|
||||
pointcloud_dir = os.path.join(generation_dir, "pointcloud")
|
||||
|
||||
# Get index etc.
|
||||
idx = data['idx'].item()
|
||||
idx = data["idx"].item()
|
||||
try:
|
||||
model_dict = dataset.get_model_dict(idx)
|
||||
except AttributeError:
|
||||
model_dict = {'model': str(idx), 'category': 'n/a'}
|
||||
model_dict = {"model": str(idx), "category": "n/a"}
|
||||
|
||||
modelname = model_dict['model']
|
||||
category_id = model_dict['category']
|
||||
modelname = model_dict["model"]
|
||||
category_id = model_dict["category"]
|
||||
|
||||
try:
|
||||
category_name = dataset.metadata[category_id].get('name', 'n/a')
|
||||
category_name = dataset.metadata[category_id].get("name", "n/a")
|
||||
except AttributeError:
|
||||
category_name = 'n/a'
|
||||
category_name = "n/a"
|
||||
|
||||
if category_id != 'n/a':
|
||||
if category_id != "n/a":
|
||||
mesh_dir = os.path.join(mesh_dir, category_id)
|
||||
pointcloud_dir = os.path.join(pointcloud_dir, category_id)
|
||||
|
||||
# Evaluate
|
||||
pointcloud_tgt = data['pointcloud'].squeeze(0).numpy()
|
||||
normals_tgt = data['pointcloud.normals'].squeeze(0).numpy()
|
||||
|
||||
pointcloud_tgt = data["pointcloud"].squeeze(0).numpy()
|
||||
normals_tgt = data["pointcloud.normals"].squeeze(0).numpy()
|
||||
|
||||
eval_dict = {
|
||||
'idx': idx,
|
||||
'class id': category_id,
|
||||
'class name': category_name,
|
||||
'modelname':modelname,
|
||||
"idx": idx,
|
||||
"class id": category_id,
|
||||
"class name": category_name,
|
||||
"modelname": modelname,
|
||||
}
|
||||
eval_dicts.append(eval_dict)
|
||||
|
||||
# Evaluate mesh
|
||||
if cfg['test']['eval_mesh']:
|
||||
mesh_file = os.path.join(mesh_dir, '%s.off' % modelname)
|
||||
if cfg["test"]["eval_mesh"]:
|
||||
mesh_file = os.path.join(mesh_dir, "%s.off" % modelname)
|
||||
|
||||
if os.path.exists(mesh_file):
|
||||
mesh = trimesh.load(mesh_file, process=False)
|
||||
eval_dict_mesh = evaluator.eval_mesh(
|
||||
mesh, pointcloud_tgt, normals_tgt)
|
||||
eval_dict_mesh = evaluator.eval_mesh(mesh, pointcloud_tgt, normals_tgt)
|
||||
for k, v in eval_dict_mesh.items():
|
||||
eval_dict[k + ' (mesh)'] = v
|
||||
eval_dict[k + " (mesh)"] = v
|
||||
else:
|
||||
print('Warning: mesh does not exist: %s' % mesh_file)
|
||||
print("Warning: mesh does not exist: %s" % mesh_file)
|
||||
|
||||
# Evaluate point cloud
|
||||
if cfg['test']['eval_pointcloud']:
|
||||
pointcloud_file = os.path.join(
|
||||
pointcloud_dir, '%s.ply' % modelname)
|
||||
if cfg["test"]["eval_pointcloud"]:
|
||||
pointcloud_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
|
||||
|
||||
if os.path.exists(pointcloud_file):
|
||||
pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
|
||||
eval_dict_pcl = evaluator.eval_pointcloud(
|
||||
pointcloud, pointcloud_tgt)
|
||||
eval_dict_pcl = evaluator.eval_pointcloud(pointcloud, pointcloud_tgt)
|
||||
for k, v in eval_dict_pcl.items():
|
||||
eval_dict[k + ' (pcl)'] = v
|
||||
eval_dict[k + " (pcl)"] = v
|
||||
else:
|
||||
print('Warning: pointcloud does not exist: %s'
|
||||
% pointcloud_file)
|
||||
|
||||
print("Warning: pointcloud does not exist: %s" % pointcloud_file)
|
||||
|
||||
# Create pandas dataframe and save
|
||||
eval_df = pd.DataFrame(eval_dicts)
|
||||
eval_df.set_index(['idx'], inplace=True)
|
||||
eval_df.set_index(["idx"], inplace=True)
|
||||
eval_df.to_pickle(out_file)
|
||||
|
||||
# Create CSV file with main statistics
|
||||
eval_df_class = eval_df.groupby(by=['class name']).mean()
|
||||
eval_df_class.loc['mean'] = eval_df_class.mean()
|
||||
eval_df_class = eval_df.groupby(by=["class name"]).mean()
|
||||
eval_df_class.loc["mean"] = eval_df_class.mean()
|
||||
eval_df_class.to_csv(out_file_class)
|
||||
|
||||
# Print results
|
||||
print(eval_df_class)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
191
generate.py
191
generate.py
|
@ -1,127 +1,136 @@
|
|||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import numpy as np; np.set_printoptions(precision=4)
|
||||
import shutil, argparse, time, os
|
||||
import pandas as pd
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from src import config
|
||||
from src.utils import mc_from_psr, export_mesh, export_pointcloud
|
||||
from src.dpsr import DPSR
|
||||
from src.training import Trainer
|
||||
from src.model import Encode2Points
|
||||
from src.utils import load_config, load_model_manual, scale2onet, is_url, load_url
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from pdb import set_trace as st
|
||||
|
||||
from src import config
|
||||
from src.dpsr import DPSR
|
||||
from src.model import Encode2Points
|
||||
from src.utils import (
|
||||
export_mesh,
|
||||
export_pointcloud,
|
||||
is_url,
|
||||
load_config,
|
||||
load_model_manual,
|
||||
load_url,
|
||||
mc_from_psr,
|
||||
scale2onet,
|
||||
)
|
||||
|
||||
np.set_printoptions(precision=4)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||
parser.add_argument('config', type=str, help='Path to config file.')
|
||||
parser.add_argument('--no_cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
||||
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
|
||||
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||
parser.add_argument("config", type=str, help="Path to config file.")
|
||||
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
|
||||
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
|
||||
parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")
|
||||
|
||||
args = parser.parse_args()
|
||||
cfg = load_config(args.config, 'configs/default.yaml')
|
||||
cfg = load_config(args.config, "configs/default.yaml")
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
data_type = cfg['data']['data_type']
|
||||
input_type = cfg['data']['input_type']
|
||||
vis_n_outputs = cfg['generation']['vis_n_outputs']
|
||||
cfg["data"]["data_type"]
|
||||
cfg["data"]["input_type"]
|
||||
vis_n_outputs = cfg["generation"]["vis_n_outputs"]
|
||||
if vis_n_outputs is None:
|
||||
vis_n_outputs = -1
|
||||
# Shorthands
|
||||
out_dir = cfg['train']['out_dir']
|
||||
out_dir = cfg["train"]["out_dir"]
|
||||
if not out_dir:
|
||||
os.makedirs(out_dir)
|
||||
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
|
||||
out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
|
||||
out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')
|
||||
generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])
|
||||
out_time_file = os.path.join(generation_dir, "time_generation_full.pkl")
|
||||
out_time_file_class = os.path.join(generation_dir, "time_generation.pkl")
|
||||
|
||||
# PYTORCH VERSION > 1.0.0
|
||||
assert(float(torch.__version__.split('.')[-3]) > 0)
|
||||
assert float(torch.__version__.split(".")[-3]) > 0
|
||||
|
||||
dataset = config.get_dataset('test', cfg, return_idx=True)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=1, num_workers=0, shuffle=False)
|
||||
dataset = config.get_dataset("test", cfg, return_idx=True)
|
||||
test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)
|
||||
|
||||
model = Encode2Points(cfg).to(device)
|
||||
|
||||
# load model
|
||||
try:
|
||||
if is_url(cfg['test']['model_file']):
|
||||
state_dict = load_url(cfg['test']['model_file'])
|
||||
elif cfg['generation'].get('iter', 0)!=0:
|
||||
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% cfg['generation']['iter']))
|
||||
generation_dir += '_%04d'%cfg['generation']['iter']
|
||||
if is_url(cfg["test"]["model_file"]):
|
||||
state_dict = load_url(cfg["test"]["model_file"])
|
||||
elif cfg["generation"].get("iter", 0) != 0:
|
||||
state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % cfg["generation"]["iter"]))
|
||||
generation_dir += "_%04d" % cfg["generation"]["iter"]
|
||||
elif args.iter is not None:
|
||||
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% args.iter))
|
||||
state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % args.iter))
|
||||
else:
|
||||
state_dict = torch.load(os.path.join(out_dir, 'model_best.pt'))
|
||||
state_dict = torch.load(os.path.join(out_dir, "model_best.pt"))
|
||||
|
||||
load_model_manual(state_dict['state_dict'], model)
|
||||
load_model_manual(state_dict["state_dict"], model)
|
||||
|
||||
except:
|
||||
print('Model loading error. Exiting.')
|
||||
print("Model loading error. Exiting.")
|
||||
exit()
|
||||
|
||||
|
||||
# Generator
|
||||
generator = config.get_generator(model, cfg, device=device)
|
||||
|
||||
# Determine what to generate
|
||||
generate_mesh = cfg['generation']['generate_mesh']
|
||||
generate_pointcloud = cfg['generation']['generate_pointcloud']
|
||||
generate_mesh = cfg["generation"]["generate_mesh"]
|
||||
generate_pointcloud = cfg["generation"]["generate_pointcloud"]
|
||||
|
||||
# Statistics
|
||||
time_dicts = []
|
||||
|
||||
# Generate
|
||||
model.eval()
|
||||
dpsr = DPSR(res=(cfg['generation']['psr_resolution'],
|
||||
cfg['generation']['psr_resolution'],
|
||||
cfg['generation']['psr_resolution']),
|
||||
sig= cfg['generation']['psr_sigma']).to(device)
|
||||
|
||||
|
||||
dpsr = DPSR(
|
||||
res=(
|
||||
cfg["generation"]["psr_resolution"],
|
||||
cfg["generation"]["psr_resolution"],
|
||||
cfg["generation"]["psr_resolution"],
|
||||
),
|
||||
sig=cfg["generation"]["psr_sigma"],
|
||||
).to(device)
|
||||
|
||||
# Count how many models already created
|
||||
model_counter = defaultdict(int)
|
||||
|
||||
print('Generating...')
|
||||
for it, data in enumerate(tqdm(test_loader)):
|
||||
|
||||
print("Generating...")
|
||||
for _it, data in enumerate(tqdm(test_loader)):
|
||||
# Output folders
|
||||
mesh_dir = os.path.join(generation_dir, 'meshes')
|
||||
in_dir = os.path.join(generation_dir, 'input')
|
||||
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
|
||||
generation_vis_dir = os.path.join(generation_dir, 'vis', )
|
||||
mesh_dir = os.path.join(generation_dir, "meshes")
|
||||
in_dir = os.path.join(generation_dir, "input")
|
||||
pointcloud_dir = os.path.join(generation_dir, "pointcloud")
|
||||
generation_vis_dir = os.path.join(generation_dir, "vis")
|
||||
|
||||
# Get index etc.
|
||||
idx = data['idx'].item()
|
||||
idx = data["idx"].item()
|
||||
|
||||
try:
|
||||
model_dict = dataset.get_model_dict(idx)
|
||||
except AttributeError:
|
||||
model_dict = {'model': str(idx), 'category': 'n/a'}
|
||||
model_dict = {"model": str(idx), "category": "n/a"}
|
||||
|
||||
modelname = model_dict['model']
|
||||
category_id = model_dict['category']
|
||||
modelname = model_dict["model"]
|
||||
category_id = model_dict["category"]
|
||||
|
||||
try:
|
||||
category_name = dataset.metadata[category_id].get('name', 'n/a')
|
||||
category_name = dataset.metadata[category_id].get("name", "n/a")
|
||||
except AttributeError:
|
||||
category_name = 'n/a'
|
||||
category_name = "n/a"
|
||||
|
||||
if category_id != 'n/a':
|
||||
if category_id != "n/a":
|
||||
mesh_dir = os.path.join(mesh_dir, str(category_id))
|
||||
pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
|
||||
in_dir = os.path.join(in_dir, str(category_id))
|
||||
|
||||
folder_name = str(category_id)
|
||||
if category_name != 'n/a':
|
||||
folder_name = str(folder_name) + '_' + category_name.split(',')[0]
|
||||
if category_name != "n/a":
|
||||
folder_name = str(folder_name) + "_" + category_name.split(",")[0]
|
||||
|
||||
generation_vis_dir = os.path.join(generation_vis_dir, folder_name)
|
||||
|
||||
|
@ -140,10 +149,10 @@ def main():
|
|||
|
||||
# Timing dict
|
||||
time_dict = {
|
||||
'idx': idx,
|
||||
'class id': category_id,
|
||||
'class name': category_name,
|
||||
'modelname':modelname,
|
||||
"idx": idx,
|
||||
"class id": category_id,
|
||||
"class name": category_name,
|
||||
"modelname": modelname,
|
||||
}
|
||||
time_dicts.append(time_dict)
|
||||
|
||||
|
@ -158,60 +167,56 @@ def main():
|
|||
time_dict.update(stats_dict)
|
||||
|
||||
# Write output
|
||||
mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname)
|
||||
mesh_out_file = os.path.join(mesh_dir, "%s.off" % modelname)
|
||||
export_mesh(mesh_out_file, scale2onet(v), f)
|
||||
out_file_dict['mesh'] = mesh_out_file
|
||||
out_file_dict["mesh"] = mesh_out_file
|
||||
|
||||
if generate_pointcloud:
|
||||
pointcloud_out_file = os.path.join(
|
||||
pointcloud_dir, '%s.ply' % modelname)
|
||||
pointcloud_out_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
|
||||
export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
|
||||
out_file_dict['pointcloud'] = pointcloud_out_file
|
||||
out_file_dict["pointcloud"] = pointcloud_out_file
|
||||
|
||||
if cfg['generation']['copy_input']:
|
||||
inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
|
||||
p = data.get('inputs').to(device)
|
||||
if cfg["generation"]["copy_input"]:
|
||||
inputs_path = os.path.join(in_dir, "%s.ply" % modelname)
|
||||
p = data.get("inputs").to(device)
|
||||
export_pointcloud(inputs_path, scale2onet(p))
|
||||
out_file_dict['in'] = inputs_path
|
||||
out_file_dict["in"] = inputs_path
|
||||
|
||||
# Copy to visualization directory for first vis_n_output samples
|
||||
c_it = model_counter[category_id]
|
||||
if c_it < vis_n_outputs:
|
||||
# Save output files
|
||||
img_name = '%02d.off' % c_it
|
||||
"%02d.off" % c_it
|
||||
for k, filepath in out_file_dict.items():
|
||||
ext = os.path.splitext(filepath)[1]
|
||||
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
|
||||
% (c_it, k, ext))
|
||||
out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, k, ext))
|
||||
shutil.copyfile(filepath, out_file)
|
||||
|
||||
# Also generate oracle meshes
|
||||
if cfg['generation']['exp_oracle']:
|
||||
points_gt = data.get('gt_points').to(device)
|
||||
normals_gt = data.get('gt_points.normals').to(device)
|
||||
if cfg["generation"]["exp_oracle"]:
|
||||
points_gt = data.get("gt_points").to(device)
|
||||
normals_gt = data.get("gt_points.normals").to(device)
|
||||
psr_gt = dpsr(points_gt, normals_gt)
|
||||
v, f, _ = mc_from_psr(psr_gt,
|
||||
zero_level=cfg['data']['zero_level'])
|
||||
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
|
||||
% (c_it, 'mesh_oracle', '.off'))
|
||||
v, f, _ = mc_from_psr(psr_gt, zero_level=cfg["data"]["zero_level"])
|
||||
out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, "mesh_oracle", ".off"))
|
||||
export_mesh(out_file, scale2onet(v), f)
|
||||
|
||||
model_counter[category_id] += 1
|
||||
|
||||
|
||||
# Create pandas dataframe and save
|
||||
time_df = pd.DataFrame(time_dicts)
|
||||
time_df.set_index(['idx'], inplace=True)
|
||||
time_df.set_index(["idx"], inplace=True)
|
||||
time_df.to_pickle(out_time_file)
|
||||
|
||||
# Create pickle files with main statistics
|
||||
time_df_class = time_df.groupby(by=['class name']).mean()
|
||||
time_df_class.loc['mean'] = time_df_class.mean()
|
||||
time_df_class = time_df.groupby(by=["class name"]).mean()
|
||||
time_df_class.loc["mean"] = time_df_class.mean()
|
||||
time_df_class.to_pickle(out_time_file_class)
|
||||
|
||||
# Print results
|
||||
print('Timings [s]:')
|
||||
print("Timings [s]:")
|
||||
print(time_df_class)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
300
optim.py
300
optim.py
|
@ -1,79 +1,86 @@
|
|||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import open3d as o3d
|
||||
import torch
|
||||
import trimesh
|
||||
import shutil, argparse, time, os, glob
|
||||
|
||||
import numpy as np; np.set_printoptions(precision=4)
|
||||
import open3d as o3d
|
||||
from plyfile import PlyData
|
||||
from pytorch3d.io import load_objs_as_meshes
|
||||
from pytorch3d.ops import sample_points_from_meshes
|
||||
from pytorch3d.structures import Meshes
|
||||
from skimage import measure
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision.utils import save_image
|
||||
from torchvision.io import write_video
|
||||
|
||||
from src.optimization import Trainer
|
||||
from src.utils import load_config, update_config, initialize_logger, \
|
||||
get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\
|
||||
update_optimizer, export_pointcloud
|
||||
from skimage import measure
|
||||
from plyfile import PlyData
|
||||
from pytorch3d.ops import sample_points_from_meshes
|
||||
from pytorch3d.io import load_objs_as_meshes
|
||||
from pytorch3d.structures import Meshes
|
||||
from src.utils import (
|
||||
AverageMeter,
|
||||
adjust_learning_rate,
|
||||
export_pointcloud,
|
||||
get_learning_rate_schedules,
|
||||
initialize_logger,
|
||||
load_config,
|
||||
update_config,
|
||||
update_optimizer,
|
||||
)
|
||||
|
||||
np.set_printoptions(precision=4)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||
parser.add_argument('config', type=str, help='Path to config file.')
|
||||
parser.add_argument('--no_cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1457, metavar='S',
|
||||
help='random seed')
|
||||
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||
parser.add_argument("config", type=str, help="Path to config file.")
|
||||
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
|
||||
parser.add_argument("--seed", type=int, default=1457, metavar="S", help="random seed")
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
cfg = load_config(args.config, 'configs/default.yaml')
|
||||
cfg = load_config(args.config, "configs/default.yaml")
|
||||
cfg = update_config(cfg, unknown)
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
data_type = cfg['data']['data_type']
|
||||
data_class = cfg['data']['class']
|
||||
data_type = cfg["data"]["data_type"]
|
||||
cfg["data"]["class"]
|
||||
|
||||
print(cfg['train']['out_dir'])
|
||||
print(cfg["train"]["out_dir"])
|
||||
|
||||
# PYTORCH VERSION > 1.0.0
|
||||
assert(float(torch.__version__.split('.')[-3]) > 0)
|
||||
assert float(torch.__version__.split(".")[-3]) > 0
|
||||
|
||||
# boiler-plate
|
||||
if cfg['train']['timestamp']:
|
||||
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
if cfg["train"]["timestamp"]:
|
||||
cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
logger = initialize_logger(cfg)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
shutil.copyfile(args.config,
|
||||
os.path.join(cfg['train']['out_dir'], 'config.yaml'))
|
||||
shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))
|
||||
|
||||
# tensorboardX writer
|
||||
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
|
||||
tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
|
||||
if not os.path.exists(tblogdir):
|
||||
os.makedirs(tblogdir)
|
||||
writer = SummaryWriter(log_dir=tblogdir)
|
||||
SummaryWriter(log_dir=tblogdir)
|
||||
|
||||
# initialize o3d visualizer
|
||||
vis = None
|
||||
if cfg['train']['o3d_show']:
|
||||
if cfg["train"]["o3d_show"]:
|
||||
vis = o3d.visualization.Visualizer()
|
||||
vis.create_window(width=cfg['train']['o3d_window_size'],
|
||||
height=cfg['train']['o3d_window_size'])
|
||||
vis.create_window(width=cfg["train"]["o3d_window_size"], height=cfg["train"]["o3d_window_size"])
|
||||
|
||||
# initialize dataset
|
||||
if data_type == 'point':
|
||||
if cfg['data']['object_id'] != -1:
|
||||
data_paths = sorted(glob.glob(cfg['data']['data_path']))
|
||||
data_path = data_paths[cfg['data']['object_id']]
|
||||
print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths)))
|
||||
if data_type == "point":
|
||||
if cfg["data"]["object_id"] != -1:
|
||||
data_paths = sorted(glob.glob(cfg["data"]["data_path"]))
|
||||
data_path = data_paths[cfg["data"]["object_id"]]
|
||||
print("Loaded %d/%d object" % (cfg["data"]["object_id"] + 1, len(data_paths)))
|
||||
else:
|
||||
data_path = cfg['data']['data_path']
|
||||
print('Data loaded')
|
||||
ext = data_path.split('.')[-1]
|
||||
if ext == 'obj': # have GT mesh
|
||||
data_path = cfg["data"]["data_path"]
|
||||
print("Data loaded")
|
||||
ext = data_path.split(".")[-1]
|
||||
if ext == "obj": # have GT mesh
|
||||
mesh = load_objs_as_meshes([data_path], device=device)
|
||||
# scale the mesh into unit cube
|
||||
verts = mesh.verts_packed()
|
||||
|
@ -81,20 +88,15 @@ def main():
|
|||
center = verts.mean(0)
|
||||
mesh.offset_verts_(-center.expand(N, 3))
|
||||
scale = max((verts - center).abs().max(0)[0])
|
||||
mesh.scale_verts_((1.0 / float(scale)))
|
||||
mesh.scale_verts_(1.0 / float(scale))
|
||||
# important for our DPSR to have the range in [0, 1), not reaching 1
|
||||
mesh.scale_verts_(0.9)
|
||||
|
||||
target_pts, target_normals = sample_points_from_meshes(mesh,
|
||||
num_samples=200000, return_normals=True)
|
||||
elif ext == 'ply': # only have the point cloud
|
||||
target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True)
|
||||
elif ext == "ply": # only have the point cloud
|
||||
plydata = PlyData.read(data_path)
|
||||
vertices = np.stack([plydata['vertex']['x'],
|
||||
plydata['vertex']['y'],
|
||||
plydata['vertex']['z']], axis=1)
|
||||
normals = np.stack([plydata['vertex']['nx'],
|
||||
plydata['vertex']['ny'],
|
||||
plydata['vertex']['nz']], axis=1)
|
||||
vertices = np.stack([plydata["vertex"]["x"], plydata["vertex"]["y"], plydata["vertex"]["z"]], axis=1)
|
||||
normals = np.stack([plydata["vertex"]["nx"], plydata["vertex"]["ny"], plydata["vertex"]["nz"]], axis=1)
|
||||
N = vertices.shape[0]
|
||||
center = vertices.mean(0)
|
||||
scale = np.max(np.max(np.abs(vertices - center), axis=0))
|
||||
|
@ -111,205 +113,205 @@ def main():
|
|||
if not torch.is_tensor(scale):
|
||||
scale = torch.from_numpy(np.array([scale]))
|
||||
|
||||
data = {'target_points': target_pts,
|
||||
'target_normals': target_normals, # normals are never used
|
||||
'gt_mesh': mesh}
|
||||
data = {
|
||||
"target_points": target_pts,
|
||||
"target_normals": target_normals, # normals are never used
|
||||
"gt_mesh": mesh,
|
||||
}
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# save the input point cloud
|
||||
if 'target_points' in data.keys():
|
||||
outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply')
|
||||
if 'target_normals' in data.keys():
|
||||
export_pointcloud(outdir_pcl, data['target_points'], data['target_normals'])
|
||||
if "target_points" in data.keys():
|
||||
outdir_pcl = os.path.join(cfg["train"]["out_dir"], "target_pcl.ply")
|
||||
if "target_normals" in data.keys():
|
||||
export_pointcloud(outdir_pcl, data["target_points"], data["target_normals"])
|
||||
else:
|
||||
export_pointcloud(outdir_pcl, data['target_points'])
|
||||
export_pointcloud(outdir_pcl, data["target_points"])
|
||||
|
||||
# save oracle PSR mesh (mesh from our PSR using GT point+normals)
|
||||
if data.get('gt_mesh') is not None:
|
||||
gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0)
|
||||
pts_gt, norms_gt = sample_points_from_meshes(data['gt_mesh'],
|
||||
num_samples=500000, return_normals=True)
|
||||
if data.get("gt_mesh") is not None:
|
||||
gt_verts, gt_faces = data["gt_mesh"].get_mesh_verts_faces(0)
|
||||
pts_gt, norms_gt = sample_points_from_meshes(data["gt_mesh"], num_samples=500000, return_normals=True)
|
||||
pts_gt = (pts_gt + 1) / 2
|
||||
from src.dpsr import DPSR
|
||||
dpsr_tmp = DPSR(res=(cfg['model']['grid_res'],
|
||||
cfg['model']['grid_res'],
|
||||
cfg['model']['grid_res']),
|
||||
sig=cfg['model']['psr_sigma']).to(device)
|
||||
|
||||
dpsr_tmp = DPSR(
|
||||
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
|
||||
sig=cfg["model"]["psr_sigma"],
|
||||
).to(device)
|
||||
target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
|
||||
target = torch.tanh(target)
|
||||
s = target.shape[-1] # size of psr_grid
|
||||
psr_grid_numpy = target.squeeze().detach().cpu().numpy()
|
||||
verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
|
||||
verts = verts / s * 2. - 1 # [-1, 1]
|
||||
verts = verts / s * 2.0 - 1 # [-1, 1]
|
||||
mesh = o3d.geometry.TriangleMesh()
|
||||
mesh.vertices = o3d.utility.Vector3dVector(verts)
|
||||
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
||||
outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply')
|
||||
outdir_mesh = os.path.join(cfg["train"]["out_dir"], "oracle_mesh.ply")
|
||||
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
|
||||
|
||||
# initialize the source point cloud given an input mesh
|
||||
if 'input_mesh' in cfg['train'].keys() and \
|
||||
os.path.isfile(cfg['train']['input_mesh']):
|
||||
if cfg['train']['input_mesh'].split('/')[-2] == 'mesh':
|
||||
mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh'])
|
||||
if "input_mesh" in cfg["train"].keys() and os.path.isfile(cfg["train"]["input_mesh"]):
|
||||
if cfg["train"]["input_mesh"].split("/")[-2] == "mesh":
|
||||
mesh_tmp = trimesh.load_mesh(cfg["train"]["input_mesh"])
|
||||
verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
|
||||
faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
|
||||
mesh = Meshes(verts=verts, faces=faces)
|
||||
points, normals = sample_points_from_meshes(mesh,
|
||||
num_samples=cfg['data']['num_points'], return_normals=True)
|
||||
points, normals = sample_points_from_meshes(
|
||||
mesh, num_samples=cfg["data"]["num_points"], return_normals=True,
|
||||
)
|
||||
# mesh is saved in the original scale of the gt
|
||||
points -= center.float().to(device)
|
||||
points /= scale.float().to(device)
|
||||
points *= 0.9
|
||||
# make sure the points are within the range of [0, 1)
|
||||
points = points / 2. + 0.5
|
||||
points = points / 2.0 + 0.5
|
||||
else:
|
||||
# directly initialize from a point cloud
|
||||
pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh'])
|
||||
pcd = o3d.io.read_point_cloud(cfg["train"]["input_mesh"])
|
||||
points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
|
||||
normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
|
||||
points -= center.float().to(device)
|
||||
points /= scale.float().to(device)
|
||||
points *= 0.9
|
||||
points = points / 2. + 0.5
|
||||
points = points / 2.0 + 0.5
|
||||
else: #! initialize our source point cloud from a sphere
|
||||
sphere_radius = cfg['model']['sphere_radius']
|
||||
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius,
|
||||
count=[256,256])
|
||||
points, idx = sphere_mesh.sample(cfg['data']['num_points'],
|
||||
return_index=True)
|
||||
sphere_radius = cfg["model"]["sphere_radius"]
|
||||
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256, 256])
|
||||
points, idx = sphere_mesh.sample(cfg["data"]["num_points"], return_index=True)
|
||||
points += 0.5 # make sure the points are within the range of [0, 1)
|
||||
normals = sphere_mesh.face_normals[idx]
|
||||
points = torch.from_numpy(points).unsqueeze(0).to(device)
|
||||
normals = torch.from_numpy(normals).unsqueeze(0).to(device)
|
||||
|
||||
|
||||
points = torch.log(points/(1-points)) # inverse sigmoid
|
||||
points = torch.log(points / (1 - points)) # inverse sigmoid
|
||||
inputs = torch.cat([points, normals], axis=-1).float()
|
||||
inputs.requires_grad = True
|
||||
|
||||
model = None # no network
|
||||
|
||||
# initialize optimizer
|
||||
cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl']
|
||||
print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial'])
|
||||
if 'schedule' in cfg['train']:
|
||||
lr_schedules = get_learning_rate_schedules(cfg['train']['schedule'])
|
||||
cfg["train"]["schedule"]["pcl"]["initial"] = cfg["train"]["lr_pcl"]
|
||||
print("Initial learning rate:", cfg["train"]["schedule"]["pcl"]["initial"])
|
||||
if "schedule" in cfg["train"]:
|
||||
lr_schedules = get_learning_rate_schedules(cfg["train"]["schedule"])
|
||||
else:
|
||||
lr_schedules = None
|
||||
|
||||
optimizer = update_optimizer(inputs, cfg,
|
||||
epoch=0, model=model, schedule=lr_schedules)
|
||||
optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules)
|
||||
|
||||
try:
|
||||
# load model
|
||||
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
||||
if ('pcl' in state_dict.keys()) & (state_dict['pcl'] is not None):
|
||||
inputs = state_dict['pcl'].to(device)
|
||||
state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
|
||||
if ("pcl" in state_dict.keys()) & (state_dict["pcl"] is not None):
|
||||
inputs = state_dict["pcl"].to(device)
|
||||
inputs.requires_grad = True
|
||||
|
||||
optimizer = update_optimizer(inputs, cfg,
|
||||
epoch=state_dict.get('epoch'), schedule=lr_schedules)
|
||||
optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get("epoch"), schedule=lr_schedules)
|
||||
|
||||
out = "Load model from epoch %d" % state_dict.get('epoch', 0)
|
||||
out = "Load model from epoch %d" % state_dict.get("epoch", 0)
|
||||
print(out)
|
||||
logger.info(out)
|
||||
except:
|
||||
state_dict = dict()
|
||||
|
||||
start_epoch = state_dict.get('epoch', -1)
|
||||
start_epoch = state_dict.get("epoch", -1)
|
||||
|
||||
trainer = Trainer(cfg, optimizer, device=device)
|
||||
runtime = {}
|
||||
runtime['all'] = AverageMeter()
|
||||
runtime["all"] = AverageMeter()
|
||||
|
||||
# training loop
|
||||
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
|
||||
|
||||
for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
|
||||
# schedule the learning rate
|
||||
if (epoch>0) & (lr_schedules is not None):
|
||||
if (epoch % lr_schedules[0].interval == 0):
|
||||
if (epoch > 0) & (lr_schedules is not None):
|
||||
if epoch % lr_schedules[0].interval == 0:
|
||||
adjust_learning_rate(lr_schedules, optimizer, epoch)
|
||||
if len(lr_schedules) >1:
|
||||
print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch,
|
||||
lr_schedules[0].get_learning_rate(epoch),
|
||||
lr_schedules[1].get_learning_rate(epoch)))
|
||||
if len(lr_schedules) > 1:
|
||||
print(
|
||||
"[epoch {}] net_lr: {}, pcl_lr: {}".format(
|
||||
epoch, lr_schedules[0].get_learning_rate(epoch), lr_schedules[1].get_learning_rate(epoch),
|
||||
),
|
||||
)
|
||||
else:
|
||||
print('[epoch {}] adjust pcl_lr to: {}'.format(epoch,
|
||||
lr_schedules[0].get_learning_rate(epoch)))
|
||||
print(f"[epoch {epoch}] adjust pcl_lr to: {lr_schedules[0].get_learning_rate(epoch)}")
|
||||
|
||||
start = time.time()
|
||||
loss, loss_each = trainer.train_step(data, inputs, model, epoch)
|
||||
runtime['all'].update(time.time() - start)
|
||||
runtime["all"].update(time.time() - start)
|
||||
|
||||
if epoch % cfg['train']['print_every'] == 0:
|
||||
log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss)
|
||||
if epoch % cfg["train"]["print_every"] == 0:
|
||||
log_text = ("[Epoch %02d] loss=%.5f") % (epoch, loss)
|
||||
if loss_each is not None:
|
||||
for k, l in loss_each.items():
|
||||
if l.item() != 0.:
|
||||
log_text += (' loss_%s=%.5f') % (k, l.item())
|
||||
if l.item() != 0.0:
|
||||
log_text += f" loss_{k}={l.item():.5f}"
|
||||
|
||||
log_text += (' time=%.3f / %.3f') % (runtime['all'].val,
|
||||
runtime['all'].sum)
|
||||
log_text += (" time={:.3f} / {:.3f}").format(runtime["all"].val, runtime["all"].sum)
|
||||
logger.info(log_text)
|
||||
print(log_text)
|
||||
|
||||
# visualize point clouds and meshes
|
||||
if (epoch % cfg['train']['visualize_every'] == 0) & (vis is not None):
|
||||
if (epoch % cfg["train"]["visualize_every"] == 0) & (vis is not None):
|
||||
trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)
|
||||
|
||||
# save outputs
|
||||
if epoch % cfg['train']['save_every'] == 0:
|
||||
trainer.save_mesh_pointclouds(inputs, epoch,
|
||||
center.cpu().numpy(),
|
||||
scale.cpu().numpy()*(1/0.9))
|
||||
if epoch % cfg["train"]["save_every"] == 0:
|
||||
trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy() * (1 / 0.9))
|
||||
|
||||
# save checkpoints
|
||||
if (epoch > 0) & (epoch % cfg['train']['checkpoint_every'] == 0):
|
||||
state = {'epoch': epoch}
|
||||
pcl = None
|
||||
if (epoch > 0) & (epoch % cfg["train"]["checkpoint_every"] == 0):
|
||||
state = {"epoch": epoch}
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
state['pcl'] = inputs.detach().cpu()
|
||||
state["pcl"] = inputs.detach().cpu()
|
||||
|
||||
torch.save(state, os.path.join(cfg['train']['dir_model'],
|
||||
'%04d' % epoch + '.pt'))
|
||||
torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % epoch + ".pt"))
|
||||
print("Save new model at epoch %d" % epoch)
|
||||
logger.info("Save new model at epoch %d" % epoch)
|
||||
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
||||
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))
|
||||
|
||||
# resample and gradually add new points to the source pcl
|
||||
if (epoch > 0) & \
|
||||
(cfg['train']['resample_every']!=0) & \
|
||||
(epoch % cfg['train']['resample_every'] == 0) & \
|
||||
(epoch < cfg['train']['total_epochs']):
|
||||
if (
|
||||
(epoch > 0)
|
||||
& (cfg["train"]["resample_every"] != 0)
|
||||
& (epoch % cfg["train"]["resample_every"] == 0)
|
||||
& (epoch < cfg["train"]["total_epochs"])
|
||||
):
|
||||
inputs = trainer.point_resampling(inputs)
|
||||
optimizer = update_optimizer(inputs, cfg,
|
||||
epoch=epoch, model=model, schedule=lr_schedules)
|
||||
optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules)
|
||||
trainer = Trainer(cfg, optimizer, device=device)
|
||||
|
||||
# visualize the Open3D outputs
|
||||
if cfg['train']['o3d_show']:
|
||||
out_video_dir = os.path.join(cfg['train']['out_dir'],
|
||||
'vis/o3d/video.mp4')
|
||||
if cfg["train"]["o3d_show"]:
|
||||
out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video.mp4")
|
||||
if os.path.isfile(out_video_dir):
|
||||
os.system('rm {}'.format(out_video_dir))
|
||||
os.system('ffmpeg -framerate 30 \
|
||||
os.system(f"rm {out_video_dir}")
|
||||
os.system(
|
||||
"ffmpeg -framerate 30 \
|
||||
-start_number 0 \
|
||||
-i {}/vis/o3d/%04d.jpg \
|
||||
-pix_fmt yuv420p \
|
||||
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
|
||||
out_video_dir = os.path.join(cfg['train']['out_dir'],
|
||||
'vis/o3d/video_pcd.mp4')
|
||||
-crf 17 {}".format(
|
||||
cfg["train"]["out_dir"], out_video_dir,
|
||||
),
|
||||
)
|
||||
out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video_pcd.mp4")
|
||||
if os.path.isfile(out_video_dir):
|
||||
os.system('rm {}'.format(out_video_dir))
|
||||
os.system('ffmpeg -framerate 30 \
|
||||
os.system(f"rm {out_video_dir}")
|
||||
os.system(
|
||||
"ffmpeg -framerate 30 \
|
||||
-start_number 0 \
|
||||
-i {}/vis/o3d/%04d_pcd.jpg \
|
||||
-pix_fmt yuv420p \
|
||||
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
|
||||
print('Video saved.')
|
||||
-crf 17 {}".format(
|
||||
cfg["train"]["out_dir"], out_video_dir,
|
||||
),
|
||||
)
|
||||
print("Video saved.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -1,69 +1,80 @@
|
|||
import sys, os
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from src.utils import load_config
|
||||
import subprocess
|
||||
os.environ['MKL_THREADING_LAYER'] = 'GNU'
|
||||
|
||||
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
||||
|
||||
|
||||
def main():
    """Run ``optim.py`` coarse-to-fine over a fixed ladder of grid resolutions.

    For every resolution in ``[32, 64, 128, 256]`` that is not below
    ``--start_res`` and not above ``cfg["model"]["grid_res"]``, one
    ``optim.py`` child process is launched. Each stage after the first is
    pointed (via ``--train:input_mesh``) at the final ``.ply`` written by the
    previous resolution; the first stage receives the literal string ``"None"``.

    Command-line arguments:
        config: path to the config file (merged over ``configs/default.yaml``).
        --start_res: stages with a smaller resolution are skipped (default -1).
        --object_id: object index; when not -1 all outputs are placed under an
            ``object_XX`` sub-directory of ``cfg["train"]["out_dir"]``.

    Raises:
        subprocess.CalledProcessError: if a stage exits with a non-zero status.
            The following stage depends on the files that stage writes, so the
            ladder is aborted instead of silently continuing.
    """
    # Local imports keep this function self-contained regardless of which
    # modules the file header currently imports.
    import shlex
    import subprocess
    import sys

    parser = argparse.ArgumentParser(
        description="Run optim.py coarse-to-fine over increasing grid resolutions",
    )
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--start_res", type=int, default=-1, help="Resolution to start with.")
    parser.add_argument("--object_id", type=int, default=-1, help="Object index.")

    # Unrecognised arguments are tolerated (and ignored), as before.
    args, _unknown = parser.parse_known_args()
    cfg = load_config(args.config, "configs/default.yaml")

    resolutions = [32, 64, 128, 256]
    iterations = [1000, 1000, 1000, 200]
    lrs = [2e-3, 2e-3 * 0.7, 2e-3 * (0.7**2), 2e-3 * (0.7**3)]  # reduce lr

    # All stages of one run share the same base directory.
    base_dir = cfg["train"]["out_dir"]
    if args.object_id != -1:
        base_dir = os.path.join(base_dir, "object_%02d" % args.object_id)

    # sample from mesh when resampling is enabled, otherwise reuse the pointcloud
    init_shape = "mesh" if cfg["train"]["resample_every"] > 0 else "pointcloud"

    # Replaces the former ``export MKL_SERVICE_FORCE_INTEL=1 &&`` shell prefix.
    child_env = dict(os.environ, MKL_SERVICE_FORCE_INTEL="1")

    for idx, (res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
        if res < args.start_res or res > cfg["model"]["grid_res"]:
            continue

        if res <= 128:
            psr_sigma = 2
        else:
            psr_sigma = 5 if "thingi_noisy" in args.config else 3

        out_dir = os.path.join(base_dir, "res_%d" % res)

        if idx == 0:
            # optim.py receives the literal string "None" for the first stage.
            input_mesh = "None"
        else:
            input_mesh = os.path.join(
                base_dir,
                "res_%d" % resolutions[idx - 1],
                "vis",
                init_shape,
                "%04d.ply" % iterations[idx - 1],
            )

        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact and cannot be used for command injection.
        cmd = [
            sys.executable or "python",
            "optim.py",
            args.config,
            "--model:grid_res",
            str(res),
            "--model:psr_sigma",
            str(psr_sigma),
            "--train:input_mesh",
            input_mesh,
            "--train:total_epochs",
            str(iteration),
            "--train:out_dir",
            out_dir,
            "--train:lr_pcl",
            "%f" % lr,  # same fixed six-decimal formatting as the old command string
            "--data:object_id",
            str(args.object_id),
        ]
        print(shlex.join(cmd))
        subprocess.run(cmd, env=child_env, check=True)
|
||||
|
||||
if __name__=="__main__":
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
55
pyproject.toml
Normal file
55
pyproject.toml
Normal file
|
@ -0,0 +1,55 @@
|
|||
[tool.ruff]
|
||||
line-length = 120
|
||||
ignore-init-module-imports = true
|
||||
ignore = [
|
||||
"G004", # Logging statement uses f-string
|
||||
"EM102", # Exception must not use an f-string literal, assign to variable first
|
||||
]
|
||||
select = [
|
||||
"A", # flake8-builtins
|
||||
"B", # flake8-bugbear
|
||||
"C90", # mccabe
|
||||
"COM", # flake8-commas
|
||||
"D", # pydocstyle
|
||||
"EM", # flake8-errmsg
|
||||
"E", # pycodestyle errors
|
||||
"F", # Pyflakes
|
||||
"G", # flake8-logging-format
|
||||
"I", # isort
|
||||
"N", # pep8-naming
|
||||
"PIE", # flake8-pie
|
||||
"PTH", # flake8-use-pathlib
|
||||
"RET", # flake8-return
|
||||
"RUF", # ruff
|
||||
"S", # flake8-bandit
|
||||
"TCH", # flake8-type-checking
|
||||
"TID", # flake8-tidy-imports
|
||||
"UP", # pyupgrade
|
||||
"W", # pycodestyle warnings
|
||||
]
|
||||
|
||||
[tool.ruff.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"__init__.py" = ["F401"]
|
||||
|
||||
[tool.black]
|
||||
exclude = '''
|
||||
/(
|
||||
\.git
|
||||
\.venv
|
||||
)/
|
||||
'''
|
||||
include = '\.pyi?$'
|
||||
line-length = 120
|
||||
target-version = ["py310"]
|
||||
|
||||
[tool.isort]
|
||||
multi_line_output = 3
|
||||
profile = "black"
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
|
@ -1,14 +1,16 @@
|
|||
import os
|
||||
import torch
|
||||
import time
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.dpsr import DPSR
|
||||
|
||||
data_path = 'data/ShapeNet' # path for ShapeNet from ONet
|
||||
base = 'data' # output base directory
|
||||
dataset_name = 'shapenet_psr'
|
||||
data_path = "data/ShapeNet" # path for ShapeNet from ONet
|
||||
base = "data" # output base directory
|
||||
dataset_name = "shapenet_psr"
|
||||
multiprocess = True
|
||||
njobs = 8
|
||||
save_pointcloud = True
|
||||
|
@ -20,20 +22,20 @@ padding = 1.2
|
|||
|
||||
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
|
||||
|
||||
def process_one(obj):
|
||||
|
||||
obj_name = obj.split('/')[-1]
|
||||
c = obj.split('/')[-2]
|
||||
def process_one(obj):
|
||||
obj_name = obj.split("/")[-1]
|
||||
c = obj.split("/")[-2]
|
||||
|
||||
# create new for the current object
|
||||
out_path_cur = os.path.join(base, dataset_name, c)
|
||||
out_path_cur_obj = os.path.join(out_path_cur, obj_name)
|
||||
os.makedirs(out_path_cur_obj, exist_ok=True)
|
||||
|
||||
gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz')
|
||||
gt_path = os.path.join(data_path, c, obj_name, "pointcloud.npz")
|
||||
data = np.load(gt_path)
|
||||
points = data['points']
|
||||
normals = data['normals']
|
||||
points = data["points"]
|
||||
normals = data["normals"]
|
||||
|
||||
# normalize the point to [0, 1)
|
||||
points = points / padding + 0.5
|
||||
|
@ -41,31 +43,35 @@ def process_one(obj):
|
|||
#! p = (p - 0.5) * padding
|
||||
|
||||
if save_pointcloud:
|
||||
outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
|
||||
outdir = os.path.join(out_path_cur_obj, "pointcloud.npz")
|
||||
# np.savez(outdir, points=points, normals=normals)
|
||||
np.savez(outdir, points=data['points'], normals=data['normals'])
|
||||
np.savez(outdir, points=data["points"], normals=data["normals"])
|
||||
# return
|
||||
|
||||
if save_psr_field:
|
||||
psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None],
|
||||
torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16)
|
||||
psr_gt = (
|
||||
dpsr(torch.from_numpy(points.astype(np.float32))[None], torch.from_numpy(normals.astype(np.float32))[None])
|
||||
.squeeze()
|
||||
.cpu()
|
||||
.numpy()
|
||||
.astype(np.float16)
|
||||
)
|
||||
|
||||
outdir = os.path.join(out_path_cur_obj, 'psr.npz')
|
||||
outdir = os.path.join(out_path_cur_obj, "psr.npz")
|
||||
np.savez(outdir, psr=psr_gt)
|
||||
|
||||
|
||||
def main(c):
|
||||
print("---------------------------------------")
|
||||
print(f"Processing {c} {split}")
|
||||
print("---------------------------------------")
|
||||
|
||||
print('---------------------------------------')
|
||||
print('Processing {} {}'.format(c, split))
|
||||
print('---------------------------------------')
|
||||
|
||||
for split in ['train', 'val', 'test']:
|
||||
fname = os.path.join(data_path, c, split+'.lst')
|
||||
with open(fname, 'r') as f:
|
||||
for split in ["train", "val", "test"]:
|
||||
fname = os.path.join(data_path, c, split + ".lst")
|
||||
with open(fname) as f:
|
||||
obj_list = f.read().splitlines()
|
||||
|
||||
obj_list = [c+'/'+s for s in obj_list]
|
||||
obj_list = [c + "/" + s for s in obj_list]
|
||||
|
||||
if multiprocess:
|
||||
# multiprocessing.set_start_method('spawn', force=True)
|
||||
|
@ -82,20 +88,29 @@ def main(c):
|
|||
for obj in tqdm(obj_list):
|
||||
process_one(obj)
|
||||
|
||||
print('Done Processing {} {}!'.format(c, split))
|
||||
print(f"Done Processing {c} {split}!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
classes = ['02691156', '02828884', '02933112',
|
||||
'02958343', '03211117', '03001627',
|
||||
'03636649', '03691459', '04090263',
|
||||
'04256520', '04379243', '04401088', '04530566']
|
||||
|
||||
classes = [
|
||||
"02691156",
|
||||
"02828884",
|
||||
"02933112",
|
||||
"02958343",
|
||||
"03211117",
|
||||
"03001627",
|
||||
"03636649",
|
||||
"03691459",
|
||||
"04090263",
|
||||
"04256520",
|
||||
"04379243",
|
||||
"04401088",
|
||||
"04530566",
|
||||
]
|
||||
|
||||
t_start = time.time()
|
||||
for c in classes:
|
||||
main(c)
|
||||
|
||||
t_end = time.time()
|
||||
print('Total processing time: ', t_end - t_start)
|
||||
print("Total processing time: ", t_end - t_start)
|
||||
|
|
139
src/config.py
139
src/config.py
|
@ -1,146 +1,151 @@
|
|||
import yaml
|
||||
from torchvision import transforms
|
||||
|
||||
from src import data, generation
|
||||
from src.dpsr import DPSR
|
||||
from ipdb import set_trace as st
|
||||
|
||||
|
||||
# Generator for final mesh extraction
|
||||
def get_generator(model, cfg, device, **kwargs):
    """Build the ``Generator3D`` object used for final mesh extraction.

    Args:
        model (nn.Module): model handed to the generator
        cfg (dict): imported yaml config
        device (device): pytorch device
        **kwargs: accepted for call-site compatibility; not used here

    Returns:
        The configured ``generation.Generator3D`` instance.
    """
    generation_cfg = cfg["generation"]

    # A generation resolution of 0 means: fall back to the model's own PSR
    # grid resolution and sigma.
    use_model_grid = generation_cfg["psr_resolution"] == 0
    if use_model_grid:
        psr_res, psr_sigma = cfg["model"]["grid_res"], cfg["model"]["psr_sigma"]
    else:
        psr_res, psr_sigma = generation_cfg["psr_resolution"], generation_cfg["psr_sigma"]

    grid_shape = (psr_res, psr_res, psr_res)
    dpsr = DPSR(res=grid_shape, sig=psr_sigma).to(device)

    return generation.Generator3D(
        model,
        device=device,
        threshold=cfg["data"]["zero_level"],
        sample=generation_cfg["use_sampling"],
        input_type=cfg["data"]["input_type"],
        padding=cfg["data"]["padding"],
        dpsr=dpsr,
        psr_tanh=cfg["model"]["psr_tanh"],
    )
|
||||
|
||||
|
||||
# Datasets
|
||||
def get_dataset(mode, cfg, return_idx=False):
|
||||
''' Returns the dataset.
|
||||
"""Returns the dataset.
|
||||
|
||||
Args:
|
||||
model (nn.Module): the model which is used
|
||||
cfg (dict): config dictionary
|
||||
return_idx (bool): whether to include an ID field
|
||||
'''
|
||||
dataset_type = cfg['data']['dataset']
|
||||
dataset_folder = cfg['data']['path']
|
||||
categories = cfg['data']['class']
|
||||
"""
|
||||
dataset_type = cfg["data"]["dataset"]
|
||||
dataset_folder = cfg["data"]["path"]
|
||||
categories = cfg["data"]["class"]
|
||||
|
||||
# Get split
|
||||
splits = {
|
||||
'train': cfg['data']['train_split'],
|
||||
'val': cfg['data']['val_split'],
|
||||
'test': cfg['data']['test_split'],
|
||||
'vis': cfg['data']['val_split'],
|
||||
"train": cfg["data"]["train_split"],
|
||||
"val": cfg["data"]["val_split"],
|
||||
"test": cfg["data"]["test_split"],
|
||||
"vis": cfg["data"]["val_split"],
|
||||
}
|
||||
|
||||
split = splits[mode]
|
||||
|
||||
# Create dataset
|
||||
if dataset_type == 'Shapes3D':
|
||||
if dataset_type == "Shapes3D":
|
||||
fields = get_data_fields(mode, cfg)
|
||||
# Input fields
|
||||
inputs_field = get_inputs_field(mode, cfg)
|
||||
if inputs_field is not None:
|
||||
fields['inputs'] = inputs_field
|
||||
fields["inputs"] = inputs_field
|
||||
|
||||
if return_idx:
|
||||
fields['idx'] = data.IndexField()
|
||||
fields["idx"] = data.IndexField()
|
||||
|
||||
dataset = data.Shapes3dDataset(
|
||||
dataset_folder, fields,
|
||||
dataset_folder,
|
||||
fields,
|
||||
split=split,
|
||||
categories=categories,
|
||||
cfg = cfg
|
||||
cfg=cfg,
|
||||
)
|
||||
else:
|
||||
raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])
|
||||
raise ValueError('Invalid dataset "%s"' % cfg["data"]["dataset"])
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def get_inputs_field(mode, cfg):
|
||||
''' Returns the inputs fields.
|
||||
"""Returns the inputs fields.
|
||||
|
||||
Args:
|
||||
mode (str): the mode which is used
|
||||
cfg (dict): config dictionary
|
||||
'''
|
||||
input_type = cfg['data']['input_type']
|
||||
"""
|
||||
input_type = cfg["data"]["input_type"]
|
||||
|
||||
if input_type is None:
|
||||
inputs_field = None
|
||||
elif input_type == 'pointcloud':
|
||||
noise_level = cfg['data']['pointcloud_noise']
|
||||
if cfg['data']['pointcloud_outlier_ratio']>0:
|
||||
transform = transforms.Compose([
|
||||
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
|
||||
elif input_type == "pointcloud":
|
||||
noise_level = cfg["data"]["pointcloud_noise"]
|
||||
if cfg["data"]["pointcloud_outlier_ratio"] > 0:
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]),
|
||||
data.PointcloudNoise(noise_level),
|
||||
data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio'])
|
||||
])
|
||||
else:
|
||||
transform = transforms.Compose([
|
||||
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
|
||||
data.PointcloudNoise(noise_level)
|
||||
])
|
||||
|
||||
data_type = cfg['data']['data_type']
|
||||
inputs_field = data.PointCloudField(
|
||||
cfg['data']['pointcloud_file'], data_type, transform,
|
||||
multi_files= cfg['data']['multi_files']
|
||||
data.PointcloudOutliers(cfg["data"]["pointcloud_outlier_ratio"]),
|
||||
],
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Invalid input type (%s)' % input_type)
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]),
|
||||
data.PointcloudNoise(noise_level),
|
||||
],
|
||||
)
|
||||
|
||||
data_type = cfg["data"]["data_type"]
|
||||
inputs_field = data.PointCloudField(
|
||||
cfg["data"]["pointcloud_file"],
|
||||
data_type,
|
||||
transform,
|
||||
multi_files=cfg["data"]["multi_files"],
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid input type (%s)" % input_type)
|
||||
return inputs_field
|
||||
|
||||
|
||||
def get_data_fields(mode, cfg):
|
||||
''' Returns the data fields.
|
||||
"""Returns the data fields.
|
||||
|
||||
Args:
|
||||
mode (str): the mode which is used
|
||||
cfg (dict): imported yaml config
|
||||
'''
|
||||
data_type = cfg['data']['data_type']
|
||||
"""
|
||||
data_type = cfg["data"]["data_type"]
|
||||
fields = {}
|
||||
|
||||
if (mode in ('val', 'test')):
|
||||
if mode in ("val", "test"):
|
||||
transform = data.SubsamplePointcloud(100000)
|
||||
else:
|
||||
transform = data.SubsamplePointcloud(cfg['data']['num_gt_points'])
|
||||
transform = data.SubsamplePointcloud(cfg["data"]["num_gt_points"])
|
||||
|
||||
data_name = cfg['data']['pointcloud_file']
|
||||
fields['gt_points'] = data.PointCloudField(data_name,
|
||||
transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files'])
|
||||
if data_type == 'psr_full':
|
||||
if mode != 'test':
|
||||
fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files'])
|
||||
data_name = cfg["data"]["pointcloud_file"]
|
||||
fields["gt_points"] = data.PointCloudField(
|
||||
data_name, transform=transform, data_type=data_type, multi_files=cfg["data"]["multi_files"],
|
||||
)
|
||||
if data_type == "psr_full":
|
||||
if mode != "test":
|
||||
fields["gt_psr"] = data.FullPSRField(multi_files=cfg["data"]["multi_files"])
|
||||
else:
|
||||
raise ValueError('Invalid data type (%s)' % data_type)
|
||||
raise ValueError("Invalid data type (%s)" % data_type)
|
||||
|
||||
return fields
|
|
@ -1,14 +1,12 @@
|
|||
|
||||
from src.data.core import (
|
||||
Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together
|
||||
)
|
||||
from src.data.fields import (
|
||||
IndexField, PointCloudField, FullPSRField
|
||||
)
|
||||
from src.data.transforms import (
|
||||
PointcloudNoise, SubsamplePointcloud,
|
||||
PointcloudOutliers,
|
||||
Shapes3dDataset,
|
||||
collate_remove_none,
|
||||
collate_stack_together,
|
||||
worker_init_fn,
|
||||
)
|
||||
from src.data.fields import FullPSRField, IndexField, PointCloudField
|
||||
from src.data.transforms import PointcloudNoise, PointcloudOutliers, SubsamplePointcloud
|
||||
|
||||
__all__ = [
|
||||
# Core
|
||||
Shapes3dDataset,
|
||||
|
|
132
src/data/core.py
132
src/data/core.py
|
@ -1,45 +1,41 @@
|
|||
import os
|
||||
import logging
|
||||
from torch.utils import data
|
||||
from pdb import set_trace as st
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from torch.utils import data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Fields
|
||||
class Field(object):
|
||||
''' Data fields class.
|
||||
'''
|
||||
class Field:
|
||||
"""Data fields class."""
|
||||
|
||||
def load(self, data_path, idx, category):
|
||||
''' Loads a data point.
|
||||
"""Loads a data point.
|
||||
|
||||
Args:
|
||||
data_path (str): path to data file
|
||||
idx (int): index of data point
|
||||
category (int): index of category
|
||||
'''
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_complete(self, files):
|
||||
''' Checks if set is complete.
|
||||
"""Checks if set is complete.
|
||||
|
||||
Args:
|
||||
files: files
|
||||
'''
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Shapes3dDataset(data.Dataset):
|
||||
''' 3D Shapes dataset class.
|
||||
'''
|
||||
"""3D Shapes dataset class."""
|
||||
|
||||
def __init__(self, dataset_folder, fields, split=None,
|
||||
categories=None, no_except=True, transform=None, cfg=None):
|
||||
''' Initialization of the the 3D shape dataset.
|
||||
def __init__(self, dataset_folder, fields, split=None, categories=None, no_except=True, transform=None, cfg=None):
|
||||
"""Initialization of the the 3D shape dataset.
|
||||
|
||||
Args:
|
||||
dataset_folder (str): dataset folder
|
||||
|
@ -49,7 +45,7 @@ class Shapes3dDataset(data.Dataset):
|
|||
no_except (bool): no exception
|
||||
transform (callable): transformation applied to data points
|
||||
cfg (yaml): config file
|
||||
'''
|
||||
"""
|
||||
# Attributes
|
||||
self.dataset_folder = dataset_folder
|
||||
self.fields = fields
|
||||
|
@ -60,76 +56,69 @@ class Shapes3dDataset(data.Dataset):
|
|||
# If categories is None, use all subfolders
|
||||
if categories is None:
|
||||
categories = os.listdir(dataset_folder)
|
||||
categories = [c for c in categories
|
||||
if os.path.isdir(os.path.join(dataset_folder, c))]
|
||||
categories = [c for c in categories if os.path.isdir(os.path.join(dataset_folder, c))]
|
||||
|
||||
# Read metadata file
|
||||
metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
|
||||
metadata_file = os.path.join(dataset_folder, "metadata.yaml")
|
||||
|
||||
if os.path.exists(metadata_file):
|
||||
with open(metadata_file, 'r') as f:
|
||||
with open(metadata_file) as f:
|
||||
self.metadata = yaml.load(f, Loader=yaml.Loader)
|
||||
else:
|
||||
self.metadata = {
|
||||
c: {'id': c, 'name': 'n/a'} for c in categories
|
||||
}
|
||||
self.metadata = {c: {"id": c, "name": "n/a"} for c in categories}
|
||||
|
||||
# Set index
|
||||
for c_idx, c in enumerate(categories):
|
||||
self.metadata[c]['idx'] = c_idx
|
||||
self.metadata[c]["idx"] = c_idx
|
||||
|
||||
# Get all models
|
||||
self.models = []
|
||||
for c_idx, c in enumerate(categories):
|
||||
subpath = os.path.join(dataset_folder, c)
|
||||
if not os.path.isdir(subpath):
|
||||
logger.warning('Category %s does not exist in dataset.' % c)
|
||||
logger.warning("Category %s does not exist in dataset." % c)
|
||||
|
||||
if split is None:
|
||||
self.models += [
|
||||
{'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ]
|
||||
{"category": c, "model": m}
|
||||
for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != "")]
|
||||
]
|
||||
|
||||
else:
|
||||
split_file = os.path.join(subpath, split + '.lst')
|
||||
with open(split_file, 'r') as f:
|
||||
models_c = f.read().split('\n')
|
||||
split_file = os.path.join(subpath, split + ".lst")
|
||||
with open(split_file) as f:
|
||||
models_c = f.read().split("\n")
|
||||
|
||||
if '' in models_c:
|
||||
models_c.remove('')
|
||||
if "" in models_c:
|
||||
models_c.remove("")
|
||||
|
||||
self.models += [
|
||||
{'category': c, 'model': m}
|
||||
for m in models_c
|
||||
]
|
||||
self.models += [{"category": c, "model": m} for m in models_c]
|
||||
|
||||
# precompute
|
||||
self.split = split
|
||||
|
||||
def __len__(self):
|
||||
''' Returns the length of the dataset.
|
||||
'''
|
||||
"""Returns the length of the dataset."""
|
||||
return len(self.models)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
''' Returns an item of the dataset.
|
||||
"""Returns an item of the dataset.
|
||||
|
||||
Args:
|
||||
idx (int): ID of data point
|
||||
'''
|
||||
|
||||
category = self.models[idx]['category']
|
||||
model = self.models[idx]['model']
|
||||
c_idx = self.metadata[category]['idx']
|
||||
"""
|
||||
category = self.models[idx]["category"]
|
||||
model = self.models[idx]["model"]
|
||||
c_idx = self.metadata[category]["idx"]
|
||||
|
||||
model_path = os.path.join(self.dataset_folder, category, model)
|
||||
data = {}
|
||||
|
||||
info = c_idx
|
||||
|
||||
if self.cfg['data']['multi_files'] is not None:
|
||||
idx = np.random.randint(self.cfg['data']['multi_files'])
|
||||
if self.split != 'train':
|
||||
if self.cfg["data"]["multi_files"] is not None:
|
||||
idx = np.random.randint(self.cfg["data"]["multi_files"])
|
||||
if self.split != "train":
|
||||
idx = 0
|
||||
|
||||
for field_name, field in self.fields.items():
|
||||
|
@ -137,9 +126,8 @@ class Shapes3dDataset(data.Dataset):
|
|||
field_data = field.load(model_path, idx, info)
|
||||
except Exception:
|
||||
if self.no_except:
|
||||
logger.warn(
|
||||
'Error occured when loading field %s of model %s'
|
||||
% (field_name, model)
|
||||
logger.warning(
|
||||
f"Error occured when loading field {field_name} of model {model}",
|
||||
)
|
||||
return None
|
||||
else:
|
||||
|
@ -150,7 +138,7 @@ class Shapes3dDataset(data.Dataset):
|
|||
if k is None:
|
||||
data[field_name] = v
|
||||
else:
|
||||
data['%s.%s' % (field_name, k)] = v
|
||||
data[f"{field_name}.{k}"] = v
|
||||
else:
|
||||
data[field_name] = field_data
|
||||
|
||||
|
@ -159,77 +147,75 @@ class Shapes3dDataset(data.Dataset):
|
|||
|
||||
return data
|
||||
|
||||
|
||||
def get_model_dict(self, idx):
|
||||
return self.models[idx]
|
||||
|
||||
def test_model_complete(self, category, model):
|
||||
''' Tests if model is complete.
|
||||
"""Tests if model is complete.
|
||||
|
||||
Args:
|
||||
model (str): modelname
|
||||
'''
|
||||
"""
|
||||
model_path = os.path.join(self.dataset_folder, category, model)
|
||||
files = os.listdir(model_path)
|
||||
for field_name, field in self.fields.items():
|
||||
if not field.check_complete(files):
|
||||
logger.warn('Field "%s" is incomplete: %s'
|
||||
% (field_name, model_path))
|
||||
logger.warning(f'Field "{field_name}" is incomplete: {model_path}')
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def collate_remove_none(batch):
|
||||
''' Collater that puts each data field into a tensor with outer dimension
|
||||
"""Collater that puts each data field into a tensor with outer dimension
|
||||
batch size.
|
||||
|
||||
Args:
|
||||
batch: batch
|
||||
'''
|
||||
|
||||
"""
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return data.dataloader.default_collate(batch)
|
||||
|
||||
|
||||
def collate_stack_together(batch):
|
||||
''' Collater that puts each data field into a tensor with outer dimension
|
||||
"""Collater that puts each data field into a tensor with outer dimension
|
||||
batch size.
|
||||
|
||||
Args:
|
||||
batch: batch
|
||||
'''
|
||||
|
||||
"""
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
keys = batch[0].keys()
|
||||
concat = {}
|
||||
if len(batch)>1:
|
||||
if len(batch) > 1:
|
||||
for key in keys:
|
||||
key_val = [item[key] for item in batch]
|
||||
concat[key] = np.concatenate(key_val, axis=0)
|
||||
if key == 'inputs':
|
||||
if key == "inputs":
|
||||
n_pts = [item[key].shape[0] for item in batch]
|
||||
|
||||
concat['batch_ind'] = np.concatenate(
|
||||
[i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
|
||||
concat["batch_ind"] = np.concatenate([i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
|
||||
|
||||
return data.dataloader.default_collate([concat])
|
||||
else:
|
||||
n_pts = batch[0]['inputs'].shape[0]
|
||||
batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int)
|
||||
n_pts = batch[0]["inputs"].shape[0]
|
||||
batch[0]["batch_ind"] = np.zeros(n_pts, dtype=int)
|
||||
return data.dataloader.default_collate(batch)
|
||||
|
||||
|
||||
def worker_init_fn(worker_id):
|
||||
''' Worker init function to ensure true randomness.
|
||||
'''
|
||||
"""Worker init function to ensure true randomness."""
|
||||
|
||||
def set_num_threads(nt):
|
||||
try:
|
||||
import mkl; mkl.set_num_threads(nt)
|
||||
import mkl
|
||||
|
||||
mkl.set_num_threads(nt)
|
||||
except:
|
||||
pass
|
||||
torch.set_num_threads(1)
|
||||
os.environ['IPC_ENABLE']='1'
|
||||
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
|
||||
os.environ["IPC_ENABLE"] = "1"
|
||||
for o in ["OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS", "MKL_NUM_THREADS"]:
|
||||
os.environ[o] = str(nt)
|
||||
|
||||
random_data = os.urandom(4)
|
||||
|
|
|
@ -1,34 +1,32 @@
|
|||
import os
|
||||
import glob
|
||||
import time
|
||||
import random
|
||||
from PIL import Image
|
||||
|
||||
import numpy as np
|
||||
import trimesh
|
||||
|
||||
from src.data.core import Field
|
||||
from pdb import set_trace as st
|
||||
|
||||
|
||||
class IndexField(Field):
|
||||
''' Basic index field.'''
|
||||
"""Basic index field."""
|
||||
|
||||
def load(self, model_path, idx, category):
|
||||
''' Loads the index field.
|
||||
"""Loads the index field.
|
||||
|
||||
Args:
|
||||
model_path (str): path to model
|
||||
idx (int): ID of data point
|
||||
category (int): index of category
|
||||
'''
|
||||
"""
|
||||
return idx
|
||||
|
||||
def check_complete(self, files):
|
||||
''' Check if field is complete.
|
||||
"""Check if field is complete.
|
||||
|
||||
Args:
|
||||
files: files
|
||||
'''
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
class FullPSRField(Field):
|
||||
def __init__(self, transform=None, multi_files=None):
|
||||
self.transform = transform
|
||||
|
@ -36,16 +34,15 @@ class FullPSRField(Field):
|
|||
self.multi_files = multi_files
|
||||
|
||||
def load(self, model_path, idx, category):
|
||||
|
||||
# try:
|
||||
# t0 = time.time()
|
||||
if self.multi_files is not None:
|
||||
psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx))
|
||||
psr_path = os.path.join(model_path, "psr", f"psr_{idx:02d}.npz")
|
||||
else:
|
||||
psr_path = os.path.join(model_path, 'psr.npz')
|
||||
psr_path = os.path.join(model_path, "psr.npz")
|
||||
psr_dict = np.load(psr_path)
|
||||
# t1 = time.time()
|
||||
psr = psr_dict['psr']
|
||||
psr = psr_dict["psr"]
|
||||
psr = psr.astype(np.float32)
|
||||
# t2 = time.time()
|
||||
# print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0))
|
||||
|
@ -56,8 +53,9 @@ class FullPSRField(Field):
|
|||
|
||||
return data
|
||||
|
||||
|
||||
class PointCloudField(Field):
|
||||
''' Point cloud field.
|
||||
"""Point cloud field.
|
||||
|
||||
It provides the field used for point cloud data. These are the points
|
||||
randomly sampled on the mesh.
|
||||
|
@ -66,7 +64,8 @@ class PointCloudField(Field):
|
|||
file_name (str): file name
|
||||
transform (list): list of transformations applied to data points
|
||||
multi_files (callable): number of files
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
|
||||
self.file_name = file_name
|
||||
self.data_type = data_type # to make sure the range of input is correct
|
||||
|
@ -76,43 +75,43 @@ class PointCloudField(Field):
|
|||
self.scale = scale
|
||||
|
||||
def load(self, model_path, idx, category):
|
||||
''' Loads the data point.
|
||||
"""Loads the data point.
|
||||
|
||||
Args:
|
||||
model_path (str): path to model
|
||||
idx (int): ID of data point
|
||||
category (int): index of category
|
||||
'''
|
||||
"""
|
||||
if self.multi_files is None:
|
||||
file_path = os.path.join(model_path, self.file_name)
|
||||
else:
|
||||
# num = np.random.randint(self.multi_files)
|
||||
# file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num))
|
||||
file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx))
|
||||
file_path = os.path.join(model_path, self.file_name, "pointcloud_%02d.npz" % (idx))
|
||||
|
||||
pointcloud_dict = np.load(file_path)
|
||||
|
||||
points = pointcloud_dict['points'].astype(np.float32)
|
||||
normals = pointcloud_dict['normals'].astype(np.float32)
|
||||
points = pointcloud_dict["points"].astype(np.float32)
|
||||
normals = pointcloud_dict["normals"].astype(np.float32)
|
||||
|
||||
data = {
|
||||
None: points,
|
||||
'normals': normals,
|
||||
"normals": normals,
|
||||
}
|
||||
if self.transform is not None:
|
||||
data = self.transform(data)
|
||||
|
||||
if self.data_type == 'psr_full':
|
||||
if self.data_type == "psr_full":
|
||||
# scale the point cloud to the range of (0, 1)
|
||||
data[None] = data[None] / self.scale + 0.5
|
||||
|
||||
return data
|
||||
|
||||
def check_complete(self, files):
|
||||
''' Check if field is complete.
|
||||
"""Check if field is complete.
|
||||
|
||||
Args:
|
||||
files: files
|
||||
'''
|
||||
complete = (self.file_name in files)
|
||||
"""
|
||||
complete = self.file_name in files
|
||||
return complete
|
||||
|
|
|
@ -2,24 +2,24 @@ import numpy as np
|
|||
|
||||
|
||||
# Transforms
|
||||
class PointcloudNoise(object):
|
||||
''' Point cloud noise transformation class.
|
||||
class PointcloudNoise:
|
||||
"""Point cloud noise transformation class.
|
||||
|
||||
It adds noise to point cloud data.
|
||||
|
||||
Args:
|
||||
stddev (int): standard deviation
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, stddev):
|
||||
self.stddev = stddev
|
||||
|
||||
def __call__(self, data):
|
||||
''' Calls the transformation.
|
||||
"""Calls the transformation.
|
||||
|
||||
Args:
|
||||
data (dictionary): data dictionary
|
||||
'''
|
||||
"""
|
||||
data_out = data.copy()
|
||||
points = data[None]
|
||||
noise = self.stddev * np.random.randn(*points.shape)
|
||||
|
@ -27,28 +27,29 @@ class PointcloudNoise(object):
|
|||
data_out[None] = points + noise
|
||||
return data_out
|
||||
|
||||
class PointcloudOutliers(object):
|
||||
''' Point cloud outlier transformation class.
|
||||
|
||||
class PointcloudOutliers:
|
||||
"""Point cloud outlier transformation class.
|
||||
|
||||
It adds outliers to point cloud data.
|
||||
|
||||
Args:
|
||||
ratio (int): outlier percentage to the entire point cloud
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, ratio):
|
||||
self.ratio = ratio
|
||||
|
||||
def __call__(self, data):
|
||||
''' Calls the transformation.
|
||||
"""Calls the transformation.
|
||||
|
||||
Args:
|
||||
data (dictionary): data dictionary
|
||||
'''
|
||||
"""
|
||||
data_out = data.copy()
|
||||
points = data[None]
|
||||
n_points = points.shape[0]
|
||||
n_outlier_points = int(n_points*self.ratio)
|
||||
n_outlier_points = int(n_points * self.ratio)
|
||||
ind = np.random.randint(0, n_points, n_outlier_points)
|
||||
|
||||
outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
|
||||
|
@ -57,30 +58,32 @@ class PointcloudOutliers(object):
|
|||
data_out[None] = points
|
||||
return data_out
|
||||
|
||||
class SubsamplePointcloud(object):
|
||||
''' Point cloud subsampling transformation class.
|
||||
|
||||
class SubsamplePointcloud:
|
||||
"""Point cloud subsampling transformation class.
|
||||
|
||||
It subsamples the point cloud data.
|
||||
|
||||
Args:
|
||||
N (int): number of points to be subsampled
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, N):
|
||||
self.N = N
|
||||
|
||||
def __call__(self, data):
|
||||
''' Calls the transformation.
|
||||
"""Calls the transformation.
|
||||
|
||||
Args:
|
||||
data (dict): data dictionary
|
||||
'''
|
||||
"""
|
||||
data_out = data.copy()
|
||||
points = data[None]
|
||||
|
||||
indices = np.random.randint(points.shape[0], size=self.N)
|
||||
data_out[None] = points[indices, :]
|
||||
if 'normals' in data.keys():
|
||||
normals = data['normals']
|
||||
data_out['normals'] = normals[indices, :]
|
||||
if "normals" in data.keys():
|
||||
normals = data["normals"]
|
||||
data_out["normals"] = normals[indices, :]
|
||||
|
||||
return data_out
|
|
@ -1,17 +1,20 @@
|
|||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from torch.utils import data
|
||||
from src.utils import load_rgb, load_mask, get_camera_params
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer import PerspectiveCameras
|
||||
from skimage import img_as_float32
|
||||
from torch.utils import data
|
||||
|
||||
from src.utils import get_camera_params, load_mask, load_rgb
|
||||
|
||||
##################################################
|
||||
# Below are for the differentiable renderer
|
||||
# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py
|
||||
|
||||
|
||||
def load_rgb(path):
|
||||
img = imageio.imread(path)
|
||||
img = img_as_float32(img)
|
||||
|
@ -23,6 +26,7 @@ def load_rgb(path):
|
|||
# img = img.transpose(2, 0, 1)
|
||||
return img
|
||||
|
||||
|
||||
def load_mask(path):
|
||||
alpha = imageio.imread(path, as_gray=True)
|
||||
alpha = img_as_float32(alpha)
|
||||
|
@ -32,10 +36,10 @@ def load_mask(path):
|
|||
|
||||
|
||||
def get_camera_params(uv, pose, intrinsics):
|
||||
if pose.shape[1] == 7: #In case of quaternion vector representation
|
||||
if pose.shape[1] == 7: # In case of quaternion vector representation
|
||||
cam_loc = pose[:, 4:]
|
||||
R = quat_to_rot(pose[:,:4])
|
||||
p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
|
||||
R = quat_to_rot(pose[:, :4])
|
||||
p = torch.eye(4).repeat(pose.shape[0], 1, 1).cuda().float()
|
||||
p[:, :3, :3] = R
|
||||
p[:, :3, 3] = cam_loc
|
||||
else: # In case of pose matrix representation
|
||||
|
@ -60,25 +64,27 @@ def get_camera_params(uv, pose, intrinsics):
|
|||
|
||||
return ray_dirs, cam_loc
|
||||
|
||||
|
||||
def quat_to_rot(q):
|
||||
batch_size, _ = q.shape
|
||||
q = F.normalize(q, dim=1)
|
||||
R = torch.ones((batch_size, 3,3)).cuda()
|
||||
qr=q[:,0]
|
||||
R = torch.ones((batch_size, 3, 3)).cuda()
|
||||
qr = q[:, 0]
|
||||
qi = q[:, 1]
|
||||
qj = q[:, 2]
|
||||
qk = q[:, 3]
|
||||
R[:, 0, 0]=1-2 * (qj**2 + qk**2)
|
||||
R[:, 0, 1] = 2 * (qj *qi -qk*qr)
|
||||
R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2)
|
||||
R[:, 0, 1] = 2 * (qj * qi - qk * qr)
|
||||
R[:, 0, 2] = 2 * (qi * qk + qr * qj)
|
||||
R[:, 1, 0] = 2 * (qj * qi + qk * qr)
|
||||
R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
|
||||
R[:, 1, 2] = 2*(qj*qk - qi*qr)
|
||||
R[:, 2, 0] = 2 * (qk * qi-qj * qr)
|
||||
R[:, 2, 1] = 2 * (qj*qk + qi*qr)
|
||||
R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
|
||||
R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2)
|
||||
R[:, 1, 2] = 2 * (qj * qk - qi * qr)
|
||||
R[:, 2, 0] = 2 * (qk * qi - qj * qr)
|
||||
R[:, 2, 1] = 2 * (qj * qk + qi * qr)
|
||||
R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2)
|
||||
return R
|
||||
|
||||
|
||||
def lift(x, y, z, intrinsics):
|
||||
# parse intrinsics
|
||||
# intrinsics = intrinsics.cuda()
|
||||
|
@ -88,7 +94,16 @@ def lift(x, y, z, intrinsics):
|
|||
cy = intrinsics[:, 1, 2]
|
||||
sk = intrinsics[:, 0, 1]
|
||||
|
||||
x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
|
||||
x_lift = (
|
||||
(
|
||||
x
|
||||
- cx.unsqueeze(-1)
|
||||
+ cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
|
||||
- sk.unsqueeze(-1) * y / fy.unsqueeze(-1)
|
||||
)
|
||||
/ fx.unsqueeze(-1)
|
||||
* z
|
||||
)
|
||||
y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
|
||||
|
||||
# homogeneous
|
||||
|
@ -96,21 +111,18 @@ def lift(x, y, z, intrinsics):
|
|||
|
||||
|
||||
class PixelNeRFDTUDataset(data.Dataset):
|
||||
"""
|
||||
Processed DTU from pixelNeRF
|
||||
"""
|
||||
def __init__(self,
|
||||
data_dir='data/DTU',
|
||||
"""Processed DTU from pixelNeRF."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir="data/DTU",
|
||||
scan_id=65,
|
||||
img_size=None,
|
||||
device=None,
|
||||
fixed_scale=0,
|
||||
):
|
||||
data_dir = os.path.join(data_dir, "scan{}".format(scan_id))
|
||||
rgb_paths = [
|
||||
x for x in glob(os.path.join(data_dir, "image", "*"))
|
||||
if (x.endswith(".jpg") or x.endswith(".png"))
|
||||
]
|
||||
data_dir = os.path.join(data_dir, f"scan{scan_id}")
|
||||
rgb_paths = [x for x in glob(os.path.join(data_dir, "image", "*")) if (x.endswith((".jpg", ".png")))]
|
||||
rgb_paths = sorted(rgb_paths)
|
||||
mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))
|
||||
if len(mask_paths) == 0:
|
||||
|
@ -129,21 +141,18 @@ class PixelNeRFDTUDataset(data.Dataset):
|
|||
all_T = []
|
||||
|
||||
for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)):
|
||||
|
||||
i = sel_indices[idx]
|
||||
rgb = load_rgb(rgb_path)
|
||||
mask = load_mask(mask_path)
|
||||
rgb[~mask] = 0.
|
||||
rgb[~mask] = 0.0
|
||||
rgb = torch.from_numpy(rgb).float().to(device)
|
||||
mask = torch.from_numpy(mask).float().to(device)
|
||||
x_scale = y_scale = 1.0
|
||||
xy_delta = 0.0
|
||||
|
||||
P = all_cam["world_mat_" + str(i)]
|
||||
P = P[:3]
|
||||
|
||||
# scale the original shape to really [-0.9, 0.9]
|
||||
if fixed_scale!=0.:
|
||||
if fixed_scale != 0.0:
|
||||
scale_mat_new = np.eye(4, 4)
|
||||
scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
|
||||
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new
|
||||
|
@ -158,38 +167,34 @@ class PixelNeRFDTUDataset(data.Dataset):
|
|||
|
||||
########!!!!!
|
||||
RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0)
|
||||
tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0)
|
||||
tt = torch.from_numpy(-R @ (t[:3] / t[3])).permute(1, 0)
|
||||
focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
|
||||
pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
|
||||
im_size = (rgb.shape[1], rgb.shape[0])
|
||||
|
||||
# check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC
|
||||
s = min(im_size)
|
||||
focal[:, 0] = focal[:, 0] * 2 / (s-1)
|
||||
focal[:, 1] = focal[:, 1] * 2 /(s-1)
|
||||
pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1)
|
||||
pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1)
|
||||
focal[:, 0] = focal[:, 0] * 2 / (s - 1)
|
||||
focal[:, 1] = focal[:, 1] * 2 / (s - 1)
|
||||
pc[:, 0] = -(pc[:, 0] - (im_size[0] - 1) / 2) * 2 / (s - 1)
|
||||
pc[:, 1] = -(pc[:, 1] - (im_size[1] - 1) / 2) * 2 / (s - 1)
|
||||
|
||||
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc,
|
||||
device=device, R=RR, T=tt)
|
||||
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc, device=device, R=RR, T=tt)
|
||||
|
||||
# calculate camera rays
|
||||
uv = uv_creation(im_size)[None].float()
|
||||
pose = np.eye(4, dtype=np.float32)
|
||||
pose[:3, :3] = R.transpose()
|
||||
pose[:3,3] = (t[:3] / t[3])[:,0]
|
||||
pose[:3, 3] = (t[:3] / t[3])[:, 0]
|
||||
pose = torch.from_numpy(pose)[None].float()
|
||||
intrinsics = np.eye(4)
|
||||
intrinsics[:3, :3] = K
|
||||
intrinsics[0, 1] = 0. #! remove skew for now
|
||||
intrinsics[0, 1] = 0.0 #! remove skew for now
|
||||
intrinsics = torch.from_numpy(intrinsics)[None].float()
|
||||
|
||||
|
||||
rays, _ = get_camera_params(uv, pose, intrinsics)
|
||||
rays = -rays.to(device)
|
||||
|
||||
|
||||
|
||||
all_poses.append(camera)
|
||||
all_imgs.append(rgb)
|
||||
all_masks.append(mask)
|
||||
|
@ -198,7 +203,7 @@ class PixelNeRFDTUDataset(data.Dataset):
|
|||
# only for neural renderer
|
||||
all_K.append(torch.tensor(K).to(device))
|
||||
all_R.append(torch.tensor(R).to(device))
|
||||
all_T.append(torch.tensor(t[:3]/t[3]).to(device))
|
||||
all_T.append(torch.tensor(t[:3] / t[3]).to(device))
|
||||
|
||||
all_imgs = torch.stack(all_imgs)
|
||||
all_masks = torch.stack(all_masks)
|
||||
|
@ -210,15 +215,16 @@ class PixelNeRFDTUDataset(data.Dataset):
|
|||
all_T = torch.stack(all_T).permute(0, 2, 1).float()
|
||||
|
||||
uv = uv_creation((all_imgs.size(2), all_imgs.size(1)))
|
||||
self.data = {'rgbs': all_imgs,
|
||||
'masks': all_masks,
|
||||
'poses': all_poses,
|
||||
'rays': all_rays,
|
||||
'uv': uv,
|
||||
'light_pose': all_light_pose, # for rendering lights
|
||||
'K': all_K,
|
||||
'R': all_R,
|
||||
'T': all_T,
|
||||
self.data = {
|
||||
"rgbs": all_imgs,
|
||||
"masks": all_masks,
|
||||
"poses": all_poses,
|
||||
"rays": all_rays,
|
||||
"uv": uv,
|
||||
"light_pose": all_light_pose, # for rendering lights
|
||||
"K": all_K,
|
||||
"R": all_R,
|
||||
"T": all_T,
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
|
|
37
src/dpsr.py
37
src/dpsr.py
|
@ -1,17 +1,17 @@
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.fft
|
||||
import torch.nn as nn
|
||||
|
||||
from src.utils import fftfreqs, grid_interp, img, point_rasterize, spec_gaussian_filter
|
||||
|
||||
|
||||
class DPSR(nn.Module):
|
||||
def __init__(self, res, sig=10, scale=True, shift=True):
|
||||
"""
|
||||
:param res: tuple of output field resolution. eg., (128,128)
|
||||
""":param res: tuple of output field resolution. eg., (128,128)
|
||||
:param sig: degree of gaussian smoothing
|
||||
"""
|
||||
super(DPSR, self).__init__()
|
||||
super().__init__()
|
||||
self.res = res
|
||||
self.sig = sig
|
||||
self.dim = len(res)
|
||||
|
@ -24,16 +24,15 @@ class DPSR(nn.Module):
|
|||
self.register_buffer("G", G)
|
||||
|
||||
def forward(self, V, N):
|
||||
"""
|
||||
:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
|
||||
""":param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
|
||||
:param N: (batch, nv, 2 or 3) tensor for point normals
|
||||
:return phi: (batch, res, res, ...) tensor of output indicator function field
|
||||
"""
|
||||
assert(V.shape == N.shape) # [b, nv, ndims]
|
||||
assert V.shape == N.shape # [b, nv, ndims]
|
||||
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
|
||||
|
||||
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
|
||||
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
|
||||
ras_s = torch.fft.rfftn(ras_p, dim=(2, 3, 4))
|
||||
ras_s = ras_s.permute(*tuple([0, *list(range(2, self.dim + 1)), self.dim + 1, 1]))
|
||||
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
|
||||
|
||||
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
|
||||
|
@ -43,12 +42,12 @@ class DPSR(nn.Module):
|
|||
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
|
||||
|
||||
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
|
||||
Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
|
||||
Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b]
|
||||
Phi = DivN / (Lap + 1e-6) # [b, dim0, dim1, dim2/2+1, 2]
|
||||
Phi = Phi.permute(*tuple([[*list(range(1, self.dim + 2)), 0]])) # [dim0, dim1, dim2/2+1, 2, b]
|
||||
Phi[tuple([0] * self.dim)] = 0
|
||||
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
|
||||
Phi = Phi.permute(*tuple([[self.dim + 1, *list(range(self.dim + 1))]])) # [b, dim0, dim1, dim2/2+1, 2]
|
||||
|
||||
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
|
||||
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1, 2, 3))
|
||||
|
||||
if self.shift or self.scale:
|
||||
# ensure values at points are zero
|
||||
|
@ -57,10 +56,10 @@ class DPSR(nn.Module):
|
|||
offset = torch.mean(fv, dim=-1) # [b,]
|
||||
phi -= offset.view(*tuple([-1] + [1] * self.dim))
|
||||
|
||||
phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
|
||||
phi = phi.permute(*tuple([[*list(range(1, self.dim + 1)), 0]]))
|
||||
fv0 = phi[tuple([0] * self.dim)] # [b,]
|
||||
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
|
||||
phi = phi.permute(*tuple([[self.dim, *list(range(self.dim))]]))
|
||||
|
||||
if self.scale:
|
||||
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
|
||||
phi = -phi / torch.abs(fv0.view(*tuple([-1] + [1] * self.dim))) * 0.5
|
||||
return phi
|
126
src/eval.py
126
src/eval.py
|
@ -1,43 +1,45 @@
|
|||
import logging
|
||||
|
||||
import numpy as np
|
||||
import trimesh
|
||||
from pykdtree.kdtree import KDTree
|
||||
|
||||
EMPTY_PCL_DICT = {
|
||||
'completeness': np.sqrt(3),
|
||||
'accuracy': np.sqrt(3),
|
||||
'completeness2': 3,
|
||||
'accuracy2': 3,
|
||||
'chamfer': 6,
|
||||
"completeness": np.sqrt(3),
|
||||
"accuracy": np.sqrt(3),
|
||||
"completeness2": 3,
|
||||
"accuracy2": 3,
|
||||
"chamfer": 6,
|
||||
}
|
||||
|
||||
EMPTY_PCL_DICT_NORMALS = {
|
||||
'normals completeness': -1.,
|
||||
'normals accuracy': -1.,
|
||||
'normals': -1.,
|
||||
"normals completeness": -1.0,
|
||||
"normals accuracy": -1.0,
|
||||
"normals": -1.0,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MeshEvaluator(object):
|
||||
''' Mesh evaluation class.
|
||||
class MeshEvaluator:
|
||||
"""Mesh evaluation class.
|
||||
It handles the mesh evaluation process.
|
||||
|
||||
Args:
|
||||
n_points (int): number of points to be used for evaluation
|
||||
'''
|
||||
n_points (int): number of points to be used for evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self, n_points=100000):
|
||||
self.n_points = n_points
|
||||
|
||||
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)):
|
||||
''' Evaluates a mesh.
|
||||
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1.0 / 1000, 1, 1000)):
|
||||
"""Evaluates a mesh.
|
||||
|
||||
Args:
|
||||
mesh (trimesh): mesh which should be evaluated
|
||||
pointcloud_tgt (numpy array): target point cloud
|
||||
normals_tgt (numpy array): target normals
|
||||
thresholds (numpy arry): for F-Score
|
||||
'''
|
||||
thresholds (numpy array): for F-Score.
|
||||
"""
|
||||
if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
|
||||
pointcloud, idx = mesh.sample(self.n_points, return_index=True)
|
||||
|
||||
|
@ -47,25 +49,25 @@ class MeshEvaluator(object):
|
|||
pointcloud = np.empty((0, 3))
|
||||
normals = np.empty((0, 3))
|
||||
|
||||
out_dict = self.eval_pointcloud(
|
||||
pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
|
||||
out_dict = self.eval_pointcloud(pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
|
||||
|
||||
return out_dict
|
||||
|
||||
def eval_pointcloud(self, pointcloud, pointcloud_tgt,
|
||||
normals=None, normals_tgt=None,
|
||||
thresholds=np.linspace(1./1000, 1, 1000)):
|
||||
''' Evaluates a point cloud.
|
||||
def eval_pointcloud(
|
||||
self, pointcloud, pointcloud_tgt, normals=None, normals_tgt=None, thresholds=np.linspace(1.0 / 1000, 1, 1000),
|
||||
):
|
||||
"""Evaluates a point cloud.
|
||||
|
||||
Args:
|
||||
pointcloud (numpy array): predicted point cloud
|
||||
pointcloud_tgt (numpy array): target point cloud
|
||||
normals (numpy array): predicted normals
|
||||
normals_tgt (numpy array): target normals
|
||||
thresholds (numpy array): threshold values for the F-score calculation
|
||||
'''
|
||||
thresholds (numpy array): threshold values for the F-score calculation.
|
||||
"""
|
||||
# Return maximum losses if pointcloud is empty
|
||||
if pointcloud.shape[0] == 0:
|
||||
logger.warn('Empty pointcloud / mesh detected!')
|
||||
logger.warning("Empty pointcloud / mesh detected!")
|
||||
out_dict = EMPTY_PCL_DICT.copy()
|
||||
if normals is not None and normals_tgt is not None:
|
||||
out_dict.update(EMPTY_PCL_DICT_NORMALS)
|
||||
|
@ -74,11 +76,13 @@ class MeshEvaluator(object):
|
|||
pointcloud = np.asarray(pointcloud)
|
||||
pointcloud_tgt = np.asarray(pointcloud_tgt)
|
||||
|
||||
|
||||
# Completeness: how far are the points of the target point cloud
|
||||
# from the predicted point cloud
|
||||
completeness, completeness_normals = distance_p2p(
|
||||
pointcloud_tgt, normals_tgt, pointcloud, normals
|
||||
pointcloud_tgt,
|
||||
normals_tgt,
|
||||
pointcloud,
|
||||
normals,
|
||||
)
|
||||
recall = get_threshold_percentage(completeness, thresholds)
|
||||
completeness2 = completeness**2
|
||||
|
@ -90,7 +94,10 @@ class MeshEvaluator(object):
|
|||
# Accuracy: how far are the points of the predicted pointcloud
|
||||
# from the target pointcloud
|
||||
accuracy, accuracy_normals = distance_p2p(
|
||||
pointcloud, normals, pointcloud_tgt, normals_tgt
|
||||
pointcloud,
|
||||
normals,
|
||||
pointcloud_tgt,
|
||||
normals_tgt,
|
||||
)
|
||||
precision = get_threshold_percentage(accuracy, thresholds)
|
||||
accuracy2 = accuracy**2
|
||||
|
@ -101,68 +108,61 @@ class MeshEvaluator(object):
|
|||
|
||||
# Chamfer distance
|
||||
chamferL2 = 0.5 * (completeness2 + accuracy2)
|
||||
normals_correctness = (
|
||||
0.5 * completeness_normals + 0.5 * accuracy_normals
|
||||
)
|
||||
normals_correctness = 0.5 * completeness_normals + 0.5 * accuracy_normals
|
||||
chamferL1 = 0.5 * (completeness + accuracy)
|
||||
|
||||
# F-Score
|
||||
F = [
|
||||
2 * precision[i] * recall[i] / (precision[i] + recall[i])
|
||||
for i in range(len(precision))
|
||||
]
|
||||
F = [2 * precision[i] * recall[i] / (precision[i] + recall[i]) for i in range(len(precision))]
|
||||
|
||||
out_dict = {
|
||||
'completeness': completeness,
|
||||
'accuracy': accuracy,
|
||||
'normals completeness': completeness_normals,
|
||||
'normals accuracy': accuracy_normals,
|
||||
'normals': normals_correctness,
|
||||
'completeness2': completeness2,
|
||||
'accuracy2': accuracy2,
|
||||
'chamfer-L2': chamferL2,
|
||||
'chamfer-L1': chamferL1,
|
||||
'f-score': F[9], # threshold = 1.0%
|
||||
'f-score-15': F[14], # threshold = 1.5%
|
||||
'f-score-20': F[19], # threshold = 2.0%
|
||||
"completeness": completeness,
|
||||
"accuracy": accuracy,
|
||||
"normals completeness": completeness_normals,
|
||||
"normals accuracy": accuracy_normals,
|
||||
"normals": normals_correctness,
|
||||
"completeness2": completeness2,
|
||||
"accuracy2": accuracy2,
|
||||
"chamfer-L2": chamferL2,
|
||||
"chamfer-L1": chamferL1,
|
||||
"f-score": F[9], # threshold = 1.0%
|
||||
"f-score-15": F[14], # threshold = 1.5%
|
||||
"f-score-20": F[19], # threshold = 2.0%
|
||||
}
|
||||
|
||||
return out_dict
|
||||
|
||||
|
||||
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
|
||||
''' Computes minimal distances of each point in points_src to points_tgt.
|
||||
"""Computes minimal distances of each point in points_src to points_tgt.
|
||||
|
||||
Args:
|
||||
points_src (numpy array): source points
|
||||
normals_src (numpy array): source normals
|
||||
points_tgt (numpy array): target points
|
||||
normals_tgt (numpy array): target normals
|
||||
'''
|
||||
normals_tgt (numpy array): target normals.
|
||||
"""
|
||||
kdtree = KDTree(points_tgt)
|
||||
dist, idx = kdtree.query(points_src)
|
||||
|
||||
if normals_src is not None and normals_tgt is not None:
|
||||
normals_src = \
|
||||
normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
|
||||
normals_tgt = \
|
||||
normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
|
||||
normals_src = normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
|
||||
normals_tgt = normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
|
||||
|
||||
normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
|
||||
# Handle normals that point into wrong direction gracefully
|
||||
# (mostly due to method not caring about this in generation)
|
||||
normals_dot_product = np.abs(normals_dot_product)
|
||||
else:
|
||||
normals_dot_product = np.array(
|
||||
[np.nan] * points_src.shape[0], dtype=np.float32)
|
||||
normals_dot_product = np.array([np.nan] * points_src.shape[0], dtype=np.float32)
|
||||
return dist, normals_dot_product
|
||||
|
||||
|
||||
def get_threshold_percentage(dist, thresholds):
|
||||
''' Evaluates a point cloud.
|
||||
"""Evaluates a point cloud.
|
||||
|
||||
Args:
|
||||
dist (numpy array): calculated distance
|
||||
thresholds (numpy array): threshold values for the F-score calculation
|
||||
'''
|
||||
in_threshold = [
|
||||
(dist <= t).mean() for t in thresholds
|
||||
]
|
||||
thresholds (numpy array): threshold values for the F-score calculation.
|
||||
"""
|
||||
in_threshold = [(dist <= t).mean() for t in thresholds]
|
||||
return in_threshold
|
|
@ -1,11 +1,12 @@
|
|||
import torch
|
||||
import time
|
||||
import trimesh
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
from src.utils import mc_from_psr
|
||||
|
||||
class Generator3D(object):
|
||||
''' Generator class for Occupancy Networks.
|
||||
|
||||
class Generator3D:
|
||||
"""Generator class for Occupancy Networks.
|
||||
|
||||
It provides functions to generate the final mesh as well as refining options.
|
||||
|
||||
|
@ -17,11 +18,20 @@ class Generator3D(object):
|
|||
padding (float): how much padding should be used for MISE
|
||||
sample (bool): whether z should be sampled
|
||||
input_type (str): type of input
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, model, points_batch_size=100000,
|
||||
threshold=0.5, device=None, padding=0.1,
|
||||
sample=False, input_type = None, dpsr=None, psr_tanh=True):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
points_batch_size=100000,
|
||||
threshold=0.5,
|
||||
device=None,
|
||||
padding=0.1,
|
||||
sample=False,
|
||||
input_type=None,
|
||||
dpsr=None,
|
||||
psr_tanh=True,
|
||||
):
|
||||
self.model = model.to(device)
|
||||
self.points_batch_size = points_batch_size
|
||||
self.threshold = threshold
|
||||
|
@ -33,29 +43,28 @@ class Generator3D(object):
|
|||
self.psr_tanh = psr_tanh
|
||||
|
||||
def generate_mesh(self, data, return_stats=True):
|
||||
''' Generates the output mesh.
|
||||
"""Generates the output mesh.
|
||||
|
||||
Args:
|
||||
data (tensor): data tensor
|
||||
return_stats (bool): whether stats should be returned
|
||||
'''
|
||||
"""
|
||||
self.model.eval()
|
||||
device = self.device
|
||||
stats_dict = {}
|
||||
|
||||
p = data.get('inputs', torch.empty(1, 0)).to(device)
|
||||
p = data.get("inputs", torch.empty(1, 0)).to(device)
|
||||
|
||||
t0 = time.time()
|
||||
points, normals = self.model(p)
|
||||
t1 = time.time()
|
||||
psr_grid = self.dpsr(points, normals)
|
||||
t2 = time.time()
|
||||
v, f, _ = mc_from_psr(psr_grid,
|
||||
zero_level=self.threshold)
|
||||
stats_dict['pcl'] = t1 - t0
|
||||
stats_dict['dpsr'] = t2 - t1
|
||||
stats_dict['mc'] = time.time() - t2
|
||||
stats_dict['total'] = time.time() - t0
|
||||
v, f, _ = mc_from_psr(psr_grid, zero_level=self.threshold)
|
||||
stats_dict["pcl"] = t1 - t0
|
||||
stats_dict["dpsr"] = t2 - t1
|
||||
stats_dict["mc"] = time.time() - t2
|
||||
stats_dict["total"] = time.time() - t0
|
||||
|
||||
if return_stats:
|
||||
return v, f, points, normals, stats_dict
|
||||
|
|
90
src/model.py
90
src/model.py
|
@ -1,18 +1,17 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
from src.utils import point_rasterize, grid_interp, mc_from_psr, \
|
||||
calc_inters_points
|
||||
from src.dpsr import DPSR
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.network import encoder_dict, decoder_dict
|
||||
|
||||
from src.network import decoder_dict, encoder_dict
|
||||
from src.network.utils import map2local
|
||||
from src.utils import calc_inters_points, grid_interp, mc_from_psr, point_rasterize
|
||||
|
||||
|
||||
class PSR2Mesh(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, psr_grid):
|
||||
"""
|
||||
In the forward pass we receive a Tensor containing the input and return
|
||||
"""In the forward pass we receive a Tensor containing the input and return
|
||||
a Tensor containing the output. ctx is a context object that can be used
|
||||
to stash information for backward computation. You can cache arbitrary
|
||||
objects for use in the backward pass using the ctx.save_for_backward method.
|
||||
|
@ -29,8 +28,7 @@ class PSR2Mesh(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
|
||||
"""
|
||||
In the backward pass we receive a Tensor containing the gradient of the loss
|
||||
"""In the backward pass we receive a Tensor containing the gradient of the loss
|
||||
with respect to the output, and we need to compute the gradient of the loss
|
||||
with respect to the input.
|
||||
"""
|
||||
|
@ -43,12 +41,12 @@ class PSR2Mesh(torch.autograd.Function):
|
|||
|
||||
return grad_grid
|
||||
|
||||
|
||||
class PSR2SurfacePoints(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
|
||||
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
|
||||
verts = verts * 2. - 1. # within the range of [-1, 1]
|
||||
|
||||
verts = verts * 2.0 - 1.0 # within the range of [-1, 1]
|
||||
|
||||
p_all, n_all, mask_all = [], [], []
|
||||
|
||||
|
@ -67,7 +65,6 @@ class PSR2SurfacePoints(torch.autograd.Function):
|
|||
n_inters_all = torch.cat(n_all, dim=0)
|
||||
mask_visible = torch.stack(mask_all, dim=0)
|
||||
|
||||
|
||||
res = torch.tensor(psr_grid.detach().shape[2])
|
||||
ctx.save_for_backward(p_inters_all, n_inters_all, res)
|
||||
|
||||
|
@ -80,30 +77,31 @@ class PSR2SurfacePoints(torch.autograd.Function):
|
|||
|
||||
# grad from the p_inters via MLP renderer
|
||||
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
|
||||
grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
|
||||
grad_grid_pts = point_rasterize((pts[None] + 1) / 2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
|
||||
|
||||
return grad_grid_pts, None, None, None, None, None
|
||||
|
||||
|
||||
class Encode2Points(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
encoder = cfg['model']['encoder']
|
||||
decoder = cfg['model']['decoder']
|
||||
dim = cfg['data']['dim'] # input dim
|
||||
c_dim = cfg['model']['c_dim']
|
||||
encoder_kwargs = cfg['model']['encoder_kwargs']
|
||||
if encoder_kwargs == None:
|
||||
encoder = cfg["model"]["encoder"]
|
||||
decoder = cfg["model"]["decoder"]
|
||||
dim = cfg["data"]["dim"] # input dim
|
||||
c_dim = cfg["model"]["c_dim"]
|
||||
encoder_kwargs = cfg["model"]["encoder_kwargs"]
|
||||
if encoder_kwargs is None:
|
||||
encoder_kwargs = {}
|
||||
decoder_kwargs = cfg['model']['decoder_kwargs']
|
||||
padding = cfg['data']['padding']
|
||||
self.predict_normal = cfg['model']['predict_normal']
|
||||
self.predict_offset = cfg['model']['predict_offset']
|
||||
decoder_kwargs = cfg["model"]["decoder_kwargs"]
|
||||
cfg["data"]["padding"]
|
||||
self.predict_normal = cfg["model"]["predict_normal"]
|
||||
self.predict_offset = cfg["model"]["predict_offset"]
|
||||
|
||||
out_dim = 3
|
||||
out_dim_offset = 3
|
||||
num_offset = cfg['data']['num_offset']
|
||||
num_offset = cfg["data"]["num_offset"]
|
||||
# each point predict more than one offset to add output points
|
||||
if num_offset > 1:
|
||||
out_dim_offset = out_dim * num_offset
|
||||
|
@ -111,43 +109,39 @@ class Encode2Points(nn.Module):
|
|||
|
||||
# local mapping
|
||||
self.map2local = None
|
||||
if cfg['model']['local_coord']:
|
||||
if 'unet' in encoder_kwargs.keys():
|
||||
unit_size = 1 / encoder_kwargs['plane_resolution']
|
||||
if cfg["model"]["local_coord"]:
|
||||
if "unet" in encoder_kwargs.keys():
|
||||
unit_size = 1 / encoder_kwargs["plane_resolution"]
|
||||
else:
|
||||
unit_size = 1 / encoder_kwargs['grid_resolution']
|
||||
unit_size = 1 / encoder_kwargs["grid_resolution"]
|
||||
|
||||
local_mapping = map2local(unit_size)
|
||||
|
||||
self.encoder = encoder_dict[encoder](
|
||||
dim=dim, c_dim=c_dim, map2local=local_mapping,
|
||||
**encoder_kwargs
|
||||
dim=dim,
|
||||
c_dim=c_dim,
|
||||
map2local=local_mapping,
|
||||
**encoder_kwargs,
|
||||
)
|
||||
|
||||
if self.predict_normal:
|
||||
# decoder for normal prediction
|
||||
self.decoder_normal = decoder_dict[decoder](
|
||||
dim=dim, c_dim=c_dim, out_dim=out_dim,
|
||||
**decoder_kwargs)
|
||||
self.decoder_normal = decoder_dict[decoder](dim=dim, c_dim=c_dim, out_dim=out_dim, **decoder_kwargs)
|
||||
if self.predict_offset:
|
||||
# decoder for offset prediction
|
||||
self.decoder_offset = decoder_dict[decoder](
|
||||
dim=dim, c_dim=c_dim, out_dim=out_dim_offset,
|
||||
map2local=local_mapping,
|
||||
**decoder_kwargs)
|
||||
|
||||
self.s_off = cfg['model']['s_offset']
|
||||
dim=dim, c_dim=c_dim, out_dim=out_dim_offset, map2local=local_mapping, **decoder_kwargs,
|
||||
)
|
||||
|
||||
self.s_off = cfg["model"]["s_offset"]
|
||||
|
||||
def forward(self, p):
|
||||
''' Performs a forward pass through the network.
|
||||
"""Performs a forward pass through the network.
|
||||
|
||||
Args:
|
||||
p (tensor): input unoriented points
|
||||
'''
|
||||
|
||||
"""
|
||||
time_dict = {}
|
||||
mask = None
|
||||
|
||||
batch_size = p.size(0)
|
||||
points = p.clone()
|
||||
|
@ -169,13 +163,11 @@ class Encode2Points(nn.Module):
|
|||
normals = self.decoder_normal(points, c)
|
||||
t2 = time.perf_counter()
|
||||
|
||||
time_dict['encode'] = t1 - t0
|
||||
time_dict['predict'] = t2 - t1
|
||||
time_dict["encode"] = t1 - t0
|
||||
time_dict["predict"] = t2 - t1
|
||||
|
||||
points = torch.clamp(points, 0.0, 0.99)
|
||||
if self.cfg['model']['normal_normalize']:
|
||||
normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8)
|
||||
|
||||
if self.cfg["model"]["normal_normalize"]:
|
||||
normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-8)
|
||||
|
||||
return points, normals
|
||||
|
|
@ -1,17 +1,19 @@
|
|||
import torch
|
||||
from pytorch3d.renderer import (
|
||||
MeshRasterizer,
|
||||
MeshRenderer,
|
||||
PerspectiveCameras,
|
||||
RasterizationSettings,
|
||||
SoftSilhouetteShader,
|
||||
)
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
from src.network.net_rgb import RenderingNetwork
|
||||
from src.utils import approx_psr_grad
|
||||
from pytorch3d.renderer import (
|
||||
RasterizationSettings,
|
||||
PerspectiveCameras,
|
||||
MeshRenderer,
|
||||
MeshRasterizer,
|
||||
SoftSilhouetteShader)
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
|
||||
def approx_psr_grad(psr_grid, res, normalize=True):
|
||||
delta_x = delta_y = delta_z = 1/res
|
||||
delta_x = delta_y = delta_z = 1 / res
|
||||
psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()
|
||||
|
||||
grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x
|
||||
|
@ -35,37 +37,37 @@ class SAP2Image(nn.Module):
|
|||
self.psr2sur = PSR2SurfacePoints.apply
|
||||
self.psr2mesh = PSR2Mesh.apply
|
||||
# initialize DPSR
|
||||
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
||||
cfg['model']['grid_res'],
|
||||
cfg['model']['grid_res']),
|
||||
sig=cfg['model']['psr_sigma'])
|
||||
self.dpsr = DPSR(
|
||||
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
|
||||
sig=cfg["model"]["psr_sigma"],
|
||||
)
|
||||
self.cfg = cfg
|
||||
if cfg['train']['l_weight']['rgb'] != 0.:
|
||||
self.rendering_network = RenderingNetwork(**cfg['model']['renderer'])
|
||||
if cfg["train"]["l_weight"]["rgb"] != 0.0:
|
||||
self.rendering_network = RenderingNetwork(**cfg["model"]["renderer"])
|
||||
|
||||
if cfg['train']['l_weight']['mask'] != 0.:
|
||||
if cfg["train"]["l_weight"]["mask"] != 0.0:
|
||||
# initialize rasterizer
|
||||
sigma = 1e-4
|
||||
raster_settings_soft = RasterizationSettings(
|
||||
image_size=img_size,
|
||||
blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
|
||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
|
||||
faces_per_pixel=150,
|
||||
perspective_correct=False
|
||||
perspective_correct=False,
|
||||
)
|
||||
|
||||
# initialize silhouette renderer
|
||||
self.mesh_rasterizer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(
|
||||
raster_settings=raster_settings_soft
|
||||
raster_settings=raster_settings_soft,
|
||||
),
|
||||
shader=SoftSilhouetteShader()
|
||||
shader=SoftSilhouetteShader(),
|
||||
)
|
||||
|
||||
self.cfg = cfg
|
||||
self.img_size = img_size
|
||||
|
||||
def forward(self, inputs, data):
|
||||
points, normals = inputs[...,:3], inputs[...,3:]
|
||||
points, normals = inputs[..., :3], inputs[..., 3:]
|
||||
points = torch.sigmoid(points)
|
||||
normals = normals / normals.norm(dim=-1, keepdim=True)
|
||||
|
||||
|
@ -76,35 +78,36 @@ class SAP2Image(nn.Module):
|
|||
return self.render_img(psr_grid, data)
|
||||
|
||||
def render_img(self, psr_grid, data):
|
||||
n_views = len(data["masks"])
|
||||
n_views_per_iter = self.cfg["data"]["n_views_per_iter"]
|
||||
|
||||
n_views = len(data['masks'])
|
||||
n_views_per_iter = self.cfg['data']['n_views_per_iter']
|
||||
|
||||
rgb_render_mode = self.cfg['model']['renderer']['mode']
|
||||
uv = data['uv']
|
||||
self.cfg["model"]["renderer"]["mode"]
|
||||
uv = data["uv"]
|
||||
|
||||
idx = np.random.randint(0, n_views, n_views_per_iter)
|
||||
pose = [data['poses'][i] for i in idx]
|
||||
rgb = data['rgbs'][idx]
|
||||
mask_gt = data['masks'][idx]
|
||||
pose = [data["poses"][i] for i in idx]
|
||||
rgb = data["rgbs"][idx]
|
||||
mask_gt = data["masks"][idx]
|
||||
ray = None
|
||||
pred_rgb = None
|
||||
pred_mask = None
|
||||
|
||||
if self.cfg['train']['l_weight']['rgb'] != 0.:
|
||||
psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res'])
|
||||
if self.cfg["train"]["l_weight"]["rgb"] != 0.0:
|
||||
psr_grad = approx_psr_grad(psr_grid, self.cfg["model"]["grid_res"])
|
||||
p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
|
||||
n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
|
||||
fea_interp = None
|
||||
if 'rays' in data.keys():
|
||||
ray = data['rays'].squeeze()[idx][visible_mask]
|
||||
pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp)
|
||||
if "rays" in data.keys():
|
||||
ray = data["rays"].squeeze()[idx][visible_mask]
|
||||
pred_rgb = self.rendering_network(
|
||||
p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp,
|
||||
)
|
||||
|
||||
# silhouette loss
|
||||
if self.cfg['train']['l_weight']['mask'] != 0.:
|
||||
if self.cfg["train"]["l_weight"]["mask"] != 0.0:
|
||||
# build mesh
|
||||
v, f, _ = self.psr2mesh(psr_grid)
|
||||
v = v * 2. - 1 # within the range of [-1, 1]
|
||||
v = v * 2.0 - 1 # within the range of [-1, 1]
|
||||
# ! Fast but more GPU usage
|
||||
mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])
|
||||
if True:
|
||||
|
@ -114,11 +117,7 @@ class SAP2Image(nn.Module):
|
|||
T = torch.cat([p.T for p in pose], dim=0)
|
||||
focal = torch.cat([p.focal_length for p in pose], dim=0)
|
||||
pp = torch.cat([p.principal_point for p in pose], dim=0)
|
||||
pose_cur = PerspectiveCameras(
|
||||
focal_length=focal,
|
||||
principal_point=pp,
|
||||
R=R, T=T,
|
||||
device=R.device)
|
||||
pose_cur = PerspectiveCameras(focal_length=focal, principal_point=pp, R=R, T=T, device=R.device)
|
||||
pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]
|
||||
else:
|
||||
pred_mask = []
|
||||
|
@ -129,11 +128,11 @@ class SAP2Image(nn.Module):
|
|||
pred_mask = torch.cat(pred_mask, dim=0)
|
||||
|
||||
output = {
|
||||
'rgb': pred_rgb,
|
||||
'rgb_gt': rgb,
|
||||
'mask': pred_mask,
|
||||
'mask_gt': mask_gt,
|
||||
'vis_mask': visible_mask,
|
||||
"rgb": pred_rgb,
|
||||
"rgb_gt": rgb,
|
||||
"mask": pred_mask,
|
||||
"mask_gt": mask_gt,
|
||||
"vis_mask": visible_mask,
|
||||
}
|
||||
|
||||
return output
|
|
@ -1,8 +1,8 @@
|
|||
from src.network import encoder, decoder
|
||||
from src.network import decoder, encoder
|
||||
|
||||
encoder_dict = {
|
||||
'local_pool_pointnet': encoder.LocalPoolPointnet,
|
||||
"local_pool_pointnet": encoder.LocalPoolPointnet,
|
||||
}
|
||||
decoder_dict = {
|
||||
'simple_local': decoder.LocalDecoder,
|
||||
"simple_local": decoder.LocalDecoder,
|
||||
}
|
|
@ -1,15 +1,17 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from ipdb import set_trace as st
|
||||
from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \
|
||||
normalize_coordinate
|
||||
|
||||
from src.network.utils import (
|
||||
ResnetBlockFC,
|
||||
normalize_3d_coordinate,
|
||||
normalize_coordinate,
|
||||
)
|
||||
|
||||
|
||||
class LocalDecoder(nn.Module):
|
||||
''' Decoder.
|
||||
"""Decoder.
|
||||
Instead of conditioning on global features, on plane/volume local features.
|
||||
|
||||
Args:
|
||||
dim (int): input dimension
|
||||
c_dim (int): dimension of latent conditioned code c
|
||||
|
@ -17,26 +19,31 @@ class LocalDecoder(nn.Module):
|
|||
n_blocks (int): number of blocks ResNetBlockFC layers
|
||||
leaky (bool): whether to use leaky ReLUs
|
||||
sample_mode (str): sampling feature strategy, bilinear|nearest
|
||||
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
||||
'''
|
||||
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55].
|
||||
"""
|
||||
|
||||
def __init__(self, dim=3, c_dim=128, out_dim=3,
|
||||
hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=3,
|
||||
c_dim=128,
|
||||
out_dim=3,
|
||||
hidden_size=256,
|
||||
n_blocks=5,
|
||||
leaky=False,
|
||||
sample_mode="bilinear",
|
||||
padding=0.1,
|
||||
map2local=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.c_dim = c_dim
|
||||
self.n_blocks = n_blocks
|
||||
|
||||
if c_dim != 0:
|
||||
self.fc_c = nn.ModuleList([
|
||||
nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
|
||||
])
|
||||
|
||||
self.fc_c = nn.ModuleList([nn.Linear(c_dim, hidden_size) for i in range(n_blocks)])
|
||||
|
||||
self.fc_p = nn.Linear(dim, hidden_size)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
ResnetBlockFC(hidden_size) for i in range(n_blocks)
|
||||
])
|
||||
self.blocks = nn.ModuleList([ResnetBlockFC(hidden_size) for i in range(n_blocks)])
|
||||
|
||||
self.fc_out = nn.Linear(hidden_size, out_dim)
|
||||
|
||||
|
@ -50,14 +57,11 @@ class LocalDecoder(nn.Module):
|
|||
self.map2local = map2local
|
||||
self.out_dim = out_dim
|
||||
|
||||
|
||||
def sample_plane_feature(self, p, c, plane='xz'):
|
||||
def sample_plane_feature(self, p, c, plane="xz"):
|
||||
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
|
||||
xy = xy[:, :, None].float()
|
||||
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
|
||||
c = F.grid_sample(c, vgrid, padding_mode='border',
|
||||
align_corners=True,
|
||||
mode=self.sample_mode).squeeze(-1)
|
||||
c = F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode).squeeze(-1)
|
||||
return c
|
||||
|
||||
def sample_grid_feature(self, p, c):
|
||||
|
@ -65,23 +69,25 @@ class LocalDecoder(nn.Module):
|
|||
p_nor = p_nor[:, :, None, None].float()
|
||||
vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
|
||||
# actually trilinear interpolation if mode = 'bilinear'
|
||||
c = F.grid_sample(c, vgrid, padding_mode='border',
|
||||
align_corners=True,
|
||||
mode=self.sample_mode).squeeze(-1).squeeze(-1)
|
||||
c = (
|
||||
F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode)
|
||||
.squeeze(-1)
|
||||
.squeeze(-1)
|
||||
)
|
||||
return c
|
||||
|
||||
def forward(self, p, c_plane, **kwargs):
|
||||
batch_size = p.shape[0]
|
||||
plane_type = list(c_plane.keys())
|
||||
c = 0
|
||||
if 'grid' in plane_type:
|
||||
c += self.sample_grid_feature(p, c_plane['grid'])
|
||||
if 'xz' in plane_type:
|
||||
c += self.sample_plane_feature(p, c_plane['xz'], plane='xz')
|
||||
if 'xy' in plane_type:
|
||||
c += self.sample_plane_feature(p, c_plane['xy'], plane='xy')
|
||||
if 'yz' in plane_type:
|
||||
c += self.sample_plane_feature(p, c_plane['yz'], plane='yz')
|
||||
if "grid" in plane_type:
|
||||
c += self.sample_grid_feature(p, c_plane["grid"])
|
||||
if "xz" in plane_type:
|
||||
c += self.sample_plane_feature(p, c_plane["xz"], plane="xz")
|
||||
if "xy" in plane_type:
|
||||
c += self.sample_plane_feature(p, c_plane["xy"], plane="xy")
|
||||
if "yz" in plane_type:
|
||||
c += self.sample_plane_feature(p, c_plane["yz"], plane="yz")
|
||||
c = c.transpose(1, 2)
|
||||
|
||||
p = p.float()
|
||||
|
@ -99,7 +105,6 @@ class LocalDecoder(nn.Module):
|
|||
|
||||
out = self.fc_out(self.actvn(net))
|
||||
|
||||
|
||||
if self.out_dim > 3:
|
||||
out = out.reshape(batch_size, -1, 3)
|
||||
|
||||
|
|
|
@ -1,17 +1,23 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from src.network.unet3d import UNet3D
|
||||
from torch_scatter import scatter_max, scatter_mean
|
||||
|
||||
from src.network.unet import UNet
|
||||
from ipdb import set_trace as st
|
||||
from torch_scatter import scatter_mean, scatter_max
|
||||
from src.network.utils import get_embedder, normalize_3d_coordinate,\
|
||||
coordinate2index, ResnetBlockFC, normalize_coordinate
|
||||
from src.network.unet3d import UNet3D
|
||||
from src.network.utils import (
|
||||
ResnetBlockFC,
|
||||
coordinate2index,
|
||||
get_embedder,
|
||||
normalize_3d_coordinate,
|
||||
normalize_coordinate,
|
||||
)
|
||||
|
||||
|
||||
class LocalPoolPointnet(nn.Module):
|
||||
''' PointNet-based encoder network with ResNet blocks for each point.
|
||||
"""PointNet-based encoder network with ResNet blocks for each point.
|
||||
Number of input points are fixed.
|
||||
|
||||
|
||||
Args:
|
||||
c_dim (int): dimension of latent code c
|
||||
dim (int): input points dimension
|
||||
|
@ -28,20 +34,32 @@ class LocalPoolPointnet(nn.Module):
|
|||
n_blocks (int): number of blocks ResNetBlockFC layers
|
||||
map2local (function): map global coordinates to local ones
|
||||
pos_encoding (int): frequency for the positional encoding
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
|
||||
unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
|
||||
plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
|
||||
map2local=None, pos_encoding=0):
|
||||
def __init__(
|
||||
self,
|
||||
c_dim=128,
|
||||
dim=3,
|
||||
hidden_dim=128,
|
||||
scatter_type="max",
|
||||
unet=False,
|
||||
unet_kwargs=None,
|
||||
unet3d=False,
|
||||
unet3d_kwargs=None,
|
||||
plane_resolution=None,
|
||||
grid_resolution=None,
|
||||
plane_type="xz",
|
||||
padding=0.1,
|
||||
n_blocks=5,
|
||||
map2local=None,
|
||||
pos_encoding=0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.c_dim = c_dim
|
||||
|
||||
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
||||
self.blocks = nn.ModuleList([
|
||||
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
|
||||
])
|
||||
self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
|
||||
self.blocks = nn.ModuleList([ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)])
|
||||
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
||||
|
||||
self.actvn = nn.ReLU()
|
||||
|
@ -60,24 +78,24 @@ class LocalPoolPointnet(nn.Module):
|
|||
self.plane_type = plane_type
|
||||
self.padding = padding
|
||||
|
||||
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
||||
self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
|
||||
self.pe = None
|
||||
if pos_encoding > 0:
|
||||
embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim)
|
||||
self.pe = embed_fn
|
||||
self.fc_pos = nn.Linear(input_ch, 2*hidden_dim)
|
||||
self.fc_pos = nn.Linear(input_ch, 2 * hidden_dim)
|
||||
|
||||
self.map2local = map2local
|
||||
|
||||
|
||||
if scatter_type == 'max':
|
||||
if scatter_type == "max":
|
||||
self.scatter = scatter_max
|
||||
elif scatter_type == 'mean':
|
||||
elif scatter_type == "mean":
|
||||
self.scatter = scatter_mean
|
||||
else:
|
||||
raise ValueError('incorrect scatter type')
|
||||
msg = "incorrect scatter type"
|
||||
raise ValueError(msg)
|
||||
|
||||
def generate_plane_features(self, p, c, plane='xz'):
|
||||
def generate_plane_features(self, p, c, plane="xz"):
|
||||
# acquire indices of features in plane
|
||||
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
|
||||
index = coordinate2index(xy, self.reso_plane)
|
||||
|
@ -86,7 +104,9 @@ class LocalPoolPointnet(nn.Module):
|
|||
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
|
||||
c = c.permute(0, 2, 1) # B x 512 x T
|
||||
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
||||
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso)
|
||||
fea_plane = fea_plane.reshape(
|
||||
p.size(0), self.c_dim, self.reso_plane, self.reso_plane,
|
||||
) # sparse matrix (B x 512 x reso x reso)
|
||||
|
||||
# process the plane features with UNet
|
||||
if self.unet is not None:
|
||||
|
@ -96,12 +116,14 @@ class LocalPoolPointnet(nn.Module):
|
|||
|
||||
def generate_grid_features(self, p, c):
|
||||
p_nor = normalize_3d_coordinate(p.clone())
|
||||
index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
|
||||
index = coordinate2index(p_nor, self.reso_grid, coord_type="3d")
|
||||
# scatter grid features from points
|
||||
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
|
||||
c = c.permute(0, 2, 1)
|
||||
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
|
||||
fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparce matrix (B x 512 x reso x reso)
|
||||
fea_grid = fea_grid.reshape(
|
||||
p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid,
|
||||
) # sparse matrix (B x 512 x reso x reso)
|
||||
|
||||
if self.unet3d is not None:
|
||||
fea_grid = self.unet3d(fea_grid)
|
||||
|
@ -115,7 +137,7 @@ class LocalPoolPointnet(nn.Module):
|
|||
c_out = 0
|
||||
for key in keys:
|
||||
# scatter plane features from points
|
||||
if key == 'grid':
|
||||
if key == "grid":
|
||||
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
|
||||
else:
|
||||
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
|
||||
|
@ -126,29 +148,27 @@ class LocalPoolPointnet(nn.Module):
|
|||
c_out += fea
|
||||
return c_out.permute(0, 2, 1)
|
||||
|
||||
|
||||
def forward(self, p, normalize=True):
|
||||
batch_size, T, D = p.size()
|
||||
|
||||
# acquire the index for each point
|
||||
coord = {}
|
||||
index = {}
|
||||
if 'xz' in self.plane_type:
|
||||
coord['xz'] = normalize_coordinate(p.clone(), plane='xz')
|
||||
index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
|
||||
if 'xy' in self.plane_type:
|
||||
coord['xy'] = normalize_coordinate(p.clone(), plane='xy')
|
||||
index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
|
||||
if 'yz' in self.plane_type:
|
||||
coord['yz'] = normalize_coordinate(p.clone(), plane='yz')
|
||||
index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
|
||||
if 'grid' in self.plane_type:
|
||||
if "xz" in self.plane_type:
|
||||
coord["xz"] = normalize_coordinate(p.clone(), plane="xz")
|
||||
index["xz"] = coordinate2index(coord["xz"], self.reso_plane)
|
||||
if "xy" in self.plane_type:
|
||||
coord["xy"] = normalize_coordinate(p.clone(), plane="xy")
|
||||
index["xy"] = coordinate2index(coord["xy"], self.reso_plane)
|
||||
if "yz" in self.plane_type:
|
||||
coord["yz"] = normalize_coordinate(p.clone(), plane="yz")
|
||||
index["yz"] = coordinate2index(coord["yz"], self.reso_plane)
|
||||
if "grid" in self.plane_type:
|
||||
if normalize:
|
||||
coord['grid'] = normalize_3d_coordinate(p.clone())
|
||||
coord["grid"] = normalize_3d_coordinate(p.clone())
|
||||
else:
|
||||
coord['grid'] = p.clone()[...,:3]
|
||||
index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
|
||||
|
||||
coord["grid"] = p.clone()[..., :3]
|
||||
index["grid"] = coordinate2index(coord["grid"], self.reso_grid, coord_type="3d")
|
||||
|
||||
if self.pe:
|
||||
p = self.pe(p)
|
||||
|
@ -169,13 +189,13 @@ class LocalPoolPointnet(nn.Module):
|
|||
c = self.fc_c(net)
|
||||
|
||||
fea = {}
|
||||
if 'grid' in self.plane_type:
|
||||
fea['grid'] = self.generate_grid_features(p, c)
|
||||
if 'xz' in self.plane_type:
|
||||
fea['xz'] = self.generate_plane_features(p, c, plane='xz')
|
||||
if 'xy' in self.plane_type:
|
||||
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
|
||||
if 'yz' in self.plane_type:
|
||||
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
|
||||
if "grid" in self.plane_type:
|
||||
fea["grid"] = self.generate_grid_features(p, c)
|
||||
if "xz" in self.plane_type:
|
||||
fea["xz"] = self.generate_plane_features(p, c, plane="xz")
|
||||
if "xy" in self.plane_type:
|
||||
fea["xy"] = self.generate_plane_features(p, c, plane="xy")
|
||||
if "yz" in self.plane_type:
|
||||
fea["yz"] = self.generate_plane_features(p, c, plane="yz")
|
||||
|
||||
return fea
|
|
@ -1,39 +1,41 @@
|
|||
# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from src.network.utils import get_embedder
|
||||
from pdb import set_trace as st
|
||||
|
||||
|
||||
class RenderingNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
fea_size=0,
|
||||
mode='naive',
|
||||
mode="naive",
|
||||
d_out=3,
|
||||
dims=[512, 512, 512, 512],
|
||||
weight_norm=True,
|
||||
pe_freq_view=0 # for positional encoding
|
||||
pe_freq_view=0, # for positional encoding
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
if mode == 'naive':
|
||||
if mode == "naive":
|
||||
d_in = 3
|
||||
elif mode == 'no_feature':
|
||||
elif mode == "no_feature":
|
||||
d_in = 3 + 3 + 3
|
||||
fea_size = 0
|
||||
elif mode == 'full':
|
||||
elif mode == "full":
|
||||
d_in = 3 + 3 + 3
|
||||
else:
|
||||
d_in = 3 + 3
|
||||
dims = [d_in + fea_size] + dims + [d_out]
|
||||
dims = [d_in + fea_size, *dims, d_out]
|
||||
|
||||
self.embedview_fn = None
|
||||
if pe_freq_view > 0:
|
||||
embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3)
|
||||
self.embedview_fn = embedview_fn
|
||||
dims[0] += (input_ch - 3)
|
||||
dims[0] += input_ch - 3
|
||||
|
||||
self.num_layers = len(dims)
|
||||
|
||||
|
@ -54,13 +56,13 @@ class RenderingNetwork(nn.Module):
|
|||
view_dirs = self.embedview_fn(view_dirs)
|
||||
# points = self.embedview_fn(points)
|
||||
|
||||
if (self.mode == 'full') & (feature_vectors is not None):
|
||||
if (self.mode == "full") & (feature_vectors is not None):
|
||||
rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
|
||||
elif (self.mode == 'no_feature') | ((self.mode == 'full') & (feature_vectors is None)):
|
||||
elif (self.mode == "no_feature") | ((self.mode == "full") & (feature_vectors is None)):
|
||||
rendering_input = torch.cat([points, view_dirs, normals], dim=-1)
|
||||
elif self.mode == 'no_view_dir':
|
||||
elif self.mode == "no_view_dir":
|
||||
rendering_input = torch.cat([points, normals], dim=-1)
|
||||
elif self.mode == 'no_normal':
|
||||
elif self.mode == "no_normal":
|
||||
rendering_input = torch.cat([points, view_dirs], dim=-1)
|
||||
else:
|
||||
rendering_input = points
|
||||
|
@ -83,25 +85,24 @@ class NeRFRenderingNetwork(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
feature_vector_size=0,
|
||||
mode='naive',
|
||||
mode="naive",
|
||||
d_in=3,
|
||||
d_out=3,
|
||||
dims=[512, 512, 512, 256],
|
||||
weight_norm=True,
|
||||
multires=0, # positional encoding of points
|
||||
multires_view=0 # positional encoding of view
|
||||
multires_view=0, # positional encoding of view
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
dims = [d_in + feature_vector_size] + dims
|
||||
|
||||
dims = [d_in + feature_vector_size, *dims]
|
||||
|
||||
self.embed_fn = None
|
||||
if multires > 0:
|
||||
embed_fn, input_ch = get_embedder(multires, d_in=d_in)
|
||||
self.embed_fn = embed_fn
|
||||
dims[0] += (input_ch - 3)
|
||||
dims[0] += input_ch - 3
|
||||
|
||||
self.num_layers = len(dims)
|
||||
|
||||
|
@ -113,13 +114,12 @@ class NeRFRenderingNetwork(nn.Module):
|
|||
self.embedview_fn = embedview_fn
|
||||
# dims[0] += (input_ch - 3)
|
||||
|
||||
if mode == 'full':
|
||||
self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)])
|
||||
if mode == "full":
|
||||
self.view_net = nn.ModuleList([nn.Linear(dims[-1] + view_ch, 128)])
|
||||
self.rgb_net = nn.Linear(128, 3)
|
||||
else:
|
||||
self.rgb_net = nn.Linear(dims[-1], 3)
|
||||
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
|
@ -134,7 +134,7 @@ class NeRFRenderingNetwork(nn.Module):
|
|||
x = net(x)
|
||||
x = self.relu(x)
|
||||
|
||||
if self.mode=='full':
|
||||
if self.mode == "full":
|
||||
x = torch.cat([x, view_dirs], -1)
|
||||
for net in self.view_net:
|
||||
x = net(x)
|
||||
|
@ -144,6 +144,7 @@ class NeRFRenderingNetwork(nn.Module):
|
|||
x = self.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
class ImplicitNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -155,11 +156,11 @@ class ImplicitNetwork(nn.Module):
|
|||
bias=1.0,
|
||||
skip_in=(),
|
||||
weight_norm=True,
|
||||
multires=0
|
||||
multires=0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dims = [d_in] + dims + [d_out + feature_vector_size]
|
||||
dims = [d_in, *dims, d_out + feature_vector_size]
|
||||
|
||||
self.embed_fn = None
|
||||
if multires > 0:
|
||||
|
@ -189,7 +190,7 @@ class ImplicitNetwork(nn.Module):
|
|||
elif multires > 0 and l in self.skip_in:
|
||||
torch.nn.init.constant_(lin.bias, 0.0)
|
||||
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
||||
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
|
||||
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
|
||||
else:
|
||||
torch.nn.init.constant_(lin.bias, 0.0)
|
||||
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
||||
|
@ -222,13 +223,9 @@ class ImplicitNetwork(nn.Module):
|
|||
|
||||
def gradient(self, x):
|
||||
x.requires_grad_(True)
|
||||
y = self.forward(x)[:,:1]
|
||||
y = self.forward(x)[:, :1]
|
||||
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
||||
gradients = torch.autograd.grad(
|
||||
outputs=y,
|
||||
inputs=x,
|
||||
grad_outputs=d_output,
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
only_inputs=True)[0]
|
||||
outputs=y, inputs=x, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True,
|
||||
)[0]
|
||||
return gradients.unsqueeze(1)
|
|
@ -1,55 +1,38 @@
|
|||
'''
|
||||
Codes are from:
|
||||
https://github.com/jaxony/unet-pytorch/blob/master/model.py
|
||||
'''
|
||||
"""Codes are from:
|
||||
https://github.com/jaxony/unet-pytorch/blob/master/model.py.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import init
|
||||
import numpy as np
|
||||
|
||||
def conv3x3(in_channels, out_channels, stride=1,
|
||||
padding=1, bias=True, groups=1):
|
||||
return nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
groups=groups)
|
||||
|
||||
def upconv2x2(in_channels, out_channels, mode='transpose'):
|
||||
if mode == 'transpose':
|
||||
return nn.ConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=2,
|
||||
stride=2)
|
||||
def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1):
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups)
|
||||
|
||||
|
||||
def upconv2x2(in_channels, out_channels, mode="transpose"):
|
||||
if mode == "transpose":
|
||||
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||
else:
|
||||
# out_channels is always going to be the same
|
||||
# as in_channels
|
||||
return nn.Sequential(
|
||||
nn.Upsample(mode='bilinear', scale_factor=2),
|
||||
conv1x1(in_channels, out_channels))
|
||||
return nn.Sequential(nn.Upsample(mode="bilinear", scale_factor=2), conv1x1(in_channels, out_channels))
|
||||
|
||||
|
||||
def conv1x1(in_channels, out_channels, groups=1):
|
||||
return nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
groups=groups,
|
||||
stride=1)
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, stride=1)
|
||||
|
||||
|
||||
class DownConv(nn.Module):
|
||||
"""
|
||||
A helper Module that performs 2 convolutions and 1 MaxPool.
|
||||
"""A helper Module that performs 2 convolutions and 1 MaxPool.
|
||||
A ReLU activation follows each convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, pooling=True):
|
||||
super(DownConv, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
@ -71,39 +54,35 @@ class DownConv(nn.Module):
|
|||
|
||||
|
||||
class UpConv(nn.Module):
|
||||
"""
|
||||
A helper Module that performs 2 convolutions and 1 UpConvolution.
|
||||
"""A helper Module that performs 2 convolutions and 1 UpConvolution.
|
||||
A ReLU activation follows each convolution.
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels,
|
||||
merge_mode='concat', up_mode='transpose'):
|
||||
super(UpConv, self).__init__()
|
||||
|
||||
def __init__(self, in_channels, out_channels, merge_mode="concat", up_mode="transpose"):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.merge_mode = merge_mode
|
||||
self.up_mode = up_mode
|
||||
|
||||
self.upconv = upconv2x2(self.in_channels, self.out_channels,
|
||||
mode=self.up_mode)
|
||||
self.upconv = upconv2x2(self.in_channels, self.out_channels, mode=self.up_mode)
|
||||
|
||||
if self.merge_mode == 'concat':
|
||||
self.conv1 = conv3x3(
|
||||
2*self.out_channels, self.out_channels)
|
||||
if self.merge_mode == "concat":
|
||||
self.conv1 = conv3x3(2 * self.out_channels, self.out_channels)
|
||||
else:
|
||||
# num of input channels to conv2 is same
|
||||
self.conv1 = conv3x3(self.out_channels, self.out_channels)
|
||||
self.conv2 = conv3x3(self.out_channels, self.out_channels)
|
||||
|
||||
|
||||
def forward(self, from_down, from_up):
|
||||
""" Forward pass
|
||||
"""Forward pass
|
||||
Arguments:
|
||||
from_down: tensor from the encoder pathway
|
||||
from_up: upconv'd tensor from the decoder pathway
|
||||
from_up: upconv'd tensor from the decoder pathway.
|
||||
"""
|
||||
from_up = self.upconv(from_up)
|
||||
if self.merge_mode == 'concat':
|
||||
if self.merge_mode == "concat":
|
||||
x = torch.cat((from_up, from_down), 1)
|
||||
else:
|
||||
x = from_up + from_down
|
||||
|
@ -113,7 +92,7 @@ class UpConv(nn.Module):
|
|||
|
||||
|
||||
class UNet(nn.Module):
|
||||
""" `UNet` class is based on https://arxiv.org/abs/1505.04597
|
||||
"""`UNet` class is based on https://arxiv.org/abs/1505.04597.
|
||||
|
||||
The U-Net is a convolutional encoder-decoder neural network.
|
||||
Contextual spatial information (from the decoding,
|
||||
|
@ -135,11 +114,10 @@ class UNet(nn.Module):
|
|||
the transpose convolution (specified by upmode='transpose')
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes, in_channels=3, depth=5,
|
||||
start_filts=64, up_mode='transpose',
|
||||
merge_mode='concat', **kwargs):
|
||||
"""
|
||||
Arguments:
|
||||
def __init__(
|
||||
self, num_classes, in_channels=3, depth=5, start_filts=64, up_mode="transpose", merge_mode="concat", **kwargs,
|
||||
):
|
||||
"""Arguments:
|
||||
in_channels: int, number of channels in the input tensor.
|
||||
Default is 3 for RGB images.
|
||||
depth: int, number of MaxPools in the U-Net.
|
||||
|
@ -149,30 +127,24 @@ class UNet(nn.Module):
|
|||
for transpose convolution or 'upsample' for nearest neighbour
|
||||
upsampling.
|
||||
"""
|
||||
super(UNet, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
if up_mode in ('transpose', 'upsample'):
|
||||
if up_mode in ("transpose", "upsample"):
|
||||
self.up_mode = up_mode
|
||||
else:
|
||||
raise ValueError("\"{}\" is not a valid mode for "
|
||||
"upsampling. Only \"transpose\" and "
|
||||
"\"upsample\" are allowed.".format(up_mode))
|
||||
msg = f'"{up_mode}" is not a valid mode for upsampling. Only "transpose" and "upsample" are allowed.'
|
||||
raise ValueError(msg)
|
||||
|
||||
if merge_mode in ('concat', 'add'):
|
||||
if merge_mode in ("concat", "add"):
|
||||
self.merge_mode = merge_mode
|
||||
else:
|
||||
raise ValueError("\"{}\" is not a valid mode for"
|
||||
"merging up and down paths. "
|
||||
"Only \"concat\" and "
|
||||
"\"add\" are allowed.".format(up_mode))
|
||||
msg = f'"{up_mode}" is not a valid mode formerging up and down paths. Only "concat" and "add" are allowed.'
|
||||
raise ValueError(msg)
|
||||
|
||||
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
|
||||
if self.up_mode == 'upsample' and self.merge_mode == 'add':
|
||||
raise ValueError("up_mode \"upsample\" is incompatible "
|
||||
"with merge_mode \"add\" at the moment "
|
||||
"because it doesn't make sense to use "
|
||||
"nearest neighbour to reduce "
|
||||
"depth channels (by half).")
|
||||
if self.up_mode == "upsample" and self.merge_mode == "add":
|
||||
msg = 'up_mode "upsample" is incompatible with merge_mode "add" at the moment because it doesn\'t make sense to use nearest neighbour to reduce depth channels (by half).'
|
||||
raise ValueError(msg)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.in_channels = in_channels
|
||||
|
@ -185,19 +157,18 @@ class UNet(nn.Module):
|
|||
# create the encoder pathway and add to a list
|
||||
for i in range(depth):
|
||||
ins = self.in_channels if i == 0 else outs
|
||||
outs = self.start_filts*(2**i)
|
||||
pooling = True if i < depth-1 else False
|
||||
outs = self.start_filts * (2**i)
|
||||
pooling = True if i < depth - 1 else False
|
||||
|
||||
down_conv = DownConv(ins, outs, pooling=pooling)
|
||||
self.down_convs.append(down_conv)
|
||||
|
||||
# create the decoder pathway and add to a list
|
||||
# - careful! decoding only requires depth-1 blocks
|
||||
for i in range(depth-1):
|
||||
for i in range(depth - 1):
|
||||
ins = outs
|
||||
outs = ins // 2
|
||||
up_conv = UpConv(ins, outs, up_mode=up_mode,
|
||||
merge_mode=merge_mode)
|
||||
up_conv = UpConv(ins, outs, up_mode=up_mode, merge_mode=merge_mode)
|
||||
self.up_convs.append(up_conv)
|
||||
|
||||
# add the list of modules to current module
|
||||
|
@ -214,12 +185,10 @@ class UNet(nn.Module):
|
|||
init.xavier_normal_(m.weight)
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
def reset_params(self):
|
||||
for i, m in enumerate(self.modules()):
|
||||
for _i, m in enumerate(self.modules()):
|
||||
self.weight_init(m)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
encoder_outs = []
|
||||
# encoder pathway, save outputs for merging
|
||||
|
@ -227,7 +196,7 @@ class UNet(nn.Module):
|
|||
x, before_pool = module(x)
|
||||
encoder_outs.append(before_pool)
|
||||
for i, module in enumerate(self.up_convs):
|
||||
before_pool = encoder_outs[-(i+2)]
|
||||
before_pool = encoder_outs[-(i + 2)]
|
||||
x = module(before_pool, x)
|
||||
|
||||
# No softmax is used. This means you need to use
|
||||
|
@ -236,21 +205,22 @@ class UNet(nn.Module):
|
|||
x = self.conv_final(x)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
testing
|
||||
"""
|
||||
model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
|
||||
model = UNet(1, depth=5, merge_mode="concat", in_channels=1, start_filts=32)
|
||||
print(model)
|
||||
print(sum(p.numel() for p in model.parameters()))
|
||||
|
||||
reso = 176
|
||||
x = np.zeros((1, 1, reso, reso))
|
||||
x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
|
||||
x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
|
||||
x = torch.FloatTensor(x)
|
||||
|
||||
out = model(x)
|
||||
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
|
||||
print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso)))
|
||||
|
||||
# loss = torch.sum(out)
|
||||
# loss.backward()
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
'''
|
||||
Code from the 3D UNet implementation:
|
||||
https://github.com/wolny/pytorch-3dunet/
|
||||
'''
|
||||
"""Code from the 3D UNet implementation:
|
||||
https://github.com/wolny/pytorch-3dunet/.
|
||||
"""
|
||||
import importlib
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from functools import partial
|
||||
|
||||
from src.network.utils import get_embedder
|
||||
|
||||
|
||||
def number_of_features_per_level(init_channel_number, num_levels):
|
||||
return [init_channel_number * 2 ** k for k in range(num_levels)]
|
||||
return [init_channel_number * 2**k for k in range(num_levels)]
|
||||
|
||||
|
||||
def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
|
||||
|
@ -18,8 +20,7 @@ def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
|
|||
|
||||
|
||||
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
|
||||
"""
|
||||
Create a list of modules with together constitute a single conv layer with non-linearity
|
||||
"""Create a list of modules with together constitute a single conv layer with non-linearity
|
||||
and optional batchnorm/groupnorm.
|
||||
|
||||
Args:
|
||||
|
@ -37,23 +38,23 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
|
|||
Return:
|
||||
list of tuple (name, module)
|
||||
"""
|
||||
assert 'c' in order, "Conv layer MUST be present"
|
||||
assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
|
||||
assert "c" in order, "Conv layer MUST be present"
|
||||
assert order[0] not in "rle", "Non-linearity cannot be the first operation in the layer"
|
||||
|
||||
modules = []
|
||||
for i, char in enumerate(order):
|
||||
if char == 'r':
|
||||
modules.append(('ReLU', nn.ReLU(inplace=True)))
|
||||
elif char == 'l':
|
||||
modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
|
||||
elif char == 'e':
|
||||
modules.append(('ELU', nn.ELU(inplace=True)))
|
||||
elif char == 'c':
|
||||
if char == "r":
|
||||
modules.append(("ReLU", nn.ReLU(inplace=True)))
|
||||
elif char == "l":
|
||||
modules.append(("LeakyReLU", nn.LeakyReLU(negative_slope=0.1, inplace=True)))
|
||||
elif char == "e":
|
||||
modules.append(("ELU", nn.ELU(inplace=True)))
|
||||
elif char == "c":
|
||||
# add learnable bias only in the absence of batchnorm/groupnorm
|
||||
bias = not ('g' in order or 'b' in order)
|
||||
modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
|
||||
elif char == 'g':
|
||||
is_before_conv = i < order.index('c')
|
||||
bias = not ("g" in order or "b" in order)
|
||||
modules.append(("conv", conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
|
||||
elif char == "g":
|
||||
is_before_conv = i < order.index("c")
|
||||
if is_before_conv:
|
||||
num_channels = in_channels
|
||||
else:
|
||||
|
@ -63,14 +64,16 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
|
|||
if num_channels < num_groups:
|
||||
num_groups = 1
|
||||
|
||||
assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
|
||||
modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
|
||||
elif char == 'b':
|
||||
is_before_conv = i < order.index('c')
|
||||
assert (
|
||||
num_channels % num_groups == 0
|
||||
), f"Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}"
|
||||
modules.append(("groupnorm", nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
|
||||
elif char == "b":
|
||||
is_before_conv = i < order.index("c")
|
||||
if is_before_conv:
|
||||
modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
|
||||
modules.append(("batchnorm", nn.BatchNorm3d(in_channels)))
|
||||
else:
|
||||
modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
|
||||
modules.append(("batchnorm", nn.BatchNorm3d(out_channels)))
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
|
||||
|
||||
|
@ -78,9 +81,8 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
|
|||
|
||||
|
||||
class SingleConv(nn.Sequential):
|
||||
"""
|
||||
Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
|
||||
of operations can be specified via the `order` parameter
|
||||
"""Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
|
||||
of operations can be specified via the `order` parameter.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input channels
|
||||
|
@ -94,16 +96,15 @@ class SingleConv(nn.Sequential):
|
|||
num_groups (int): number of groups for the GroupNorm
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
|
||||
super(SingleConv, self).__init__()
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8, padding=1):
|
||||
super().__init__()
|
||||
|
||||
for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
|
||||
self.add_module(name, module)
|
||||
|
||||
|
||||
class DoubleConv(nn.Sequential):
|
||||
"""
|
||||
A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
|
||||
"""A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
|
||||
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
||||
This can be changed however by providing the 'order' argument, e.g. in order
|
||||
to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
|
||||
|
@ -123,8 +124,8 @@ class DoubleConv(nn.Sequential):
|
|||
num_groups (int): number of groups for the GroupNorm
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
|
||||
super(DoubleConv, self).__init__()
|
||||
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order="crg", num_groups=8):
|
||||
super().__init__()
|
||||
if encoder:
|
||||
# we're in the encoder path
|
||||
conv1_in_channels = in_channels
|
||||
|
@ -138,26 +139,27 @@ class DoubleConv(nn.Sequential):
|
|||
conv2_in_channels, conv2_out_channels = out_channels, out_channels
|
||||
|
||||
# conv1
|
||||
self.add_module('SingleConv1',
|
||||
SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
|
||||
self.add_module(
|
||||
"SingleConv1", SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups),
|
||||
)
|
||||
# conv2
|
||||
self.add_module('SingleConv2',
|
||||
SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
|
||||
self.add_module(
|
||||
"SingleConv2", SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups),
|
||||
)
|
||||
|
||||
|
||||
class ExtResNetBlock(nn.Module):
|
||||
"""
|
||||
Basic UNet block consisting of a SingleConv followed by the residual block.
|
||||
"""Basic UNet block consisting of a SingleConv followed by the residual block.
|
||||
The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
|
||||
of output channels is compatible with the residual block that follows.
|
||||
This block can be used instead of standard DoubleConv in the Encoder module.
|
||||
Motivated by: https://arxiv.org/pdf/1706.00120.pdf
|
||||
Motivated by: https://arxiv.org/pdf/1706.00120.pdf.
|
||||
|
||||
Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
|
||||
super(ExtResNetBlock, self).__init__()
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, order="cge", num_groups=8, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
# first convolution
|
||||
self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
|
||||
|
@ -165,15 +167,16 @@ class ExtResNetBlock(nn.Module):
|
|||
self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
|
||||
# remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
|
||||
n_order = order
|
||||
for c in 'rel':
|
||||
n_order = n_order.replace(c, '')
|
||||
self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
|
||||
num_groups=num_groups)
|
||||
for c in "rel":
|
||||
n_order = n_order.replace(c, "")
|
||||
self.conv3 = SingleConv(
|
||||
out_channels, out_channels, kernel_size=kernel_size, order=n_order, num_groups=num_groups,
|
||||
)
|
||||
|
||||
# create non-linearity separately
|
||||
if 'l' in order:
|
||||
if "l" in order:
|
||||
self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
elif 'e' in order:
|
||||
elif "e" in order:
|
||||
self.non_linearity = nn.ELU(inplace=True)
|
||||
else:
|
||||
self.non_linearity = nn.ReLU(inplace=True)
|
||||
|
@ -194,12 +197,12 @@ class ExtResNetBlock(nn.Module):
|
|||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""
|
||||
A single module from the encoder path consisting of the optional max
|
||||
"""A single module from the encoder path consisting of the optional max
|
||||
pooling layer (one may specify the MaxPool kernel_size to be different
|
||||
than the standard (2,2,2), e.g. if the volumetric data is anisotropic
|
||||
(make sure to use complementary scale_factor in the decoder path) followed by
|
||||
a DoubleConv module.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input channels
|
||||
out_channels (int): number of output channels
|
||||
|
@ -210,27 +213,39 @@ class Encoder(nn.Module):
|
|||
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
||||
conv_layer_order (string): determines the order of layers
|
||||
in `DoubleConv` module. See `DoubleConv` for more info.
|
||||
num_groups (int): number of groups for the GroupNorm
|
||||
num_groups (int): number of groups for the GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
|
||||
pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg',
|
||||
num_groups=8):
|
||||
super(Encoder, self).__init__()
|
||||
assert pool_type in ['max', 'avg']
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_kernel_size=3,
|
||||
apply_pooling=True,
|
||||
pool_kernel_size=(2, 2, 2),
|
||||
pool_type="max",
|
||||
basic_module=DoubleConv,
|
||||
conv_layer_order="crg",
|
||||
num_groups=8,
|
||||
):
|
||||
super().__init__()
|
||||
assert pool_type in ["max", "avg"]
|
||||
if apply_pooling:
|
||||
if pool_type == 'max':
|
||||
if pool_type == "max":
|
||||
self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
|
||||
else:
|
||||
self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
|
||||
else:
|
||||
self.pooling = None
|
||||
|
||||
self.basic_module = basic_module(in_channels, out_channels,
|
||||
self.basic_module = basic_module(
|
||||
in_channels,
|
||||
out_channels,
|
||||
encoder=True,
|
||||
kernel_size=conv_kernel_size,
|
||||
order=conv_layer_order,
|
||||
num_groups=num_groups)
|
||||
num_groups=num_groups,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.pooling is not None:
|
||||
|
@ -240,9 +255,9 @@ class Encoder(nn.Module):
|
|||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""
|
||||
A single module for decoder path consisting of the upsampling layer
|
||||
"""A single module for decoder path consisting of the upsampling layer
|
||||
(either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input channels
|
||||
out_channels (int): number of output channels
|
||||
|
@ -253,32 +268,56 @@ class Decoder(nn.Module):
|
|||
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
||||
conv_layer_order (string): determines the order of layers
|
||||
in `DoubleConv` module. See `DoubleConv` for more info.
|
||||
num_groups (int): number of groups for the GroupNorm
|
||||
num_groups (int): number of groups for the GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
|
||||
conv_layer_order='crg', num_groups=8, mode='nearest'):
|
||||
super(Decoder, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
scale_factor=(2, 2, 2),
|
||||
basic_module=DoubleConv,
|
||||
conv_layer_order="crg",
|
||||
num_groups=8,
|
||||
mode="nearest",
|
||||
):
|
||||
super().__init__()
|
||||
if basic_module == DoubleConv:
|
||||
# if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
|
||||
self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
|
||||
self.upsampling = Upsampling(
|
||||
transposed_conv=False,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
)
|
||||
# concat joining
|
||||
self.joining = partial(self._joining, concat=True)
|
||||
else:
|
||||
# if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
|
||||
self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
|
||||
self.upsampling = Upsampling(
|
||||
transposed_conv=True,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
)
|
||||
# sum joining
|
||||
self.joining = partial(self._joining, concat=False)
|
||||
# adapt the number of in_channels for the ExtResNetBlock
|
||||
in_channels = out_channels
|
||||
|
||||
self.basic_module = basic_module(in_channels, out_channels,
|
||||
self.basic_module = basic_module(
|
||||
in_channels,
|
||||
out_channels,
|
||||
encoder=False,
|
||||
kernel_size=kernel_size,
|
||||
order=conv_layer_order,
|
||||
num_groups=num_groups)
|
||||
num_groups=num_groups,
|
||||
)
|
||||
|
||||
def forward(self, encoder_features, x):
|
||||
x = self.upsampling(encoder_features=encoder_features, x=x)
|
||||
|
@ -295,8 +334,7 @@ class Decoder(nn.Module):
|
|||
|
||||
|
||||
class Upsampling(nn.Module):
|
||||
"""
|
||||
Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
|
||||
"""Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
|
||||
|
||||
Args:
|
||||
transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
|
||||
|
@ -310,15 +348,23 @@ class Upsampling(nn.Module):
|
|||
'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
|
||||
"""
|
||||
|
||||
def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3,
|
||||
scale_factor=(2, 2, 2), mode='nearest'):
|
||||
super(Upsampling, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
transposed_conv,
|
||||
in_channels=None,
|
||||
out_channels=None,
|
||||
kernel_size=3,
|
||||
scale_factor=(2, 2, 2),
|
||||
mode="nearest",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if transposed_conv:
|
||||
# make sure that the output size reverses the MaxPool3d from the corresponding encoder
|
||||
# (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0])
|
||||
self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
|
||||
padding=1)
|
||||
# (D_out = (D_in - 1) x stride[0] - 2 x padding[0] + kernel_size[0] + output_padding[0])
|
||||
self.upsample = nn.ConvTranspose3d(
|
||||
in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, padding=1,
|
||||
)
|
||||
else:
|
||||
self.upsample = partial(self._interpolate, mode=mode)
|
||||
|
||||
|
@ -332,13 +378,13 @@ class Upsampling(nn.Module):
|
|||
|
||||
|
||||
class FinalConv(nn.Sequential):
|
||||
"""
|
||||
A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
|
||||
"""A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
|
||||
which reduces the number of channels to 'out_channels'.
|
||||
with the number of output channels 'out_channels // 2' and 'out_channels' respectively.
|
||||
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
||||
This can be change however by providing the 'order' argument, e.g. in order
|
||||
to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input channels
|
||||
out_channels (int): number of output channels
|
||||
|
@ -346,22 +392,22 @@ class FinalConv(nn.Sequential):
|
|||
order (string): determines the order of layers, e.g.
|
||||
'cr' -> conv + ReLU
|
||||
'crg' -> conv + ReLU + groupnorm
|
||||
num_groups (int): number of groups for the GroupNorm
|
||||
num_groups (int): number of groups for the GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
|
||||
super(FinalConv, self).__init__()
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8):
|
||||
super().__init__()
|
||||
|
||||
# conv1
|
||||
self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
|
||||
self.add_module("SingleConv", SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
|
||||
|
||||
# in the last layer a 1×1 convolution reduces the number of output channels to out_channels
|
||||
# in the last layer a 1x1 convolution reduces the number of output channels to out_channels
|
||||
final_conv = nn.Conv3d(in_channels, out_channels, 1)
|
||||
self.add_module('final_conv', final_conv)
|
||||
self.add_module("final_conv", final_conv)
|
||||
|
||||
|
||||
class Abstract3DUNet(nn.Module):
|
||||
"""
|
||||
Base class for standard and residual UNet.
|
||||
"""Base class for standard and residual UNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input channels
|
||||
|
@ -391,9 +437,22 @@ class Abstract3DUNet(nn.Module):
|
|||
and the `final_activation` (even if present) won't be applied; default: False
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
|
||||
num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs):
|
||||
super(Abstract3DUNet, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
final_sigmoid,
|
||||
basic_module,
|
||||
f_maps=64,
|
||||
layer_order="gcr",
|
||||
num_groups=8,
|
||||
num_levels=4,
|
||||
is_segmentation=False,
|
||||
testing=False,
|
||||
pe_freq=0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.testing = testing
|
||||
|
||||
|
@ -411,13 +470,24 @@ class Abstract3DUNet(nn.Module):
|
|||
encoders = []
|
||||
for i, out_feature_num in enumerate(f_maps):
|
||||
if i == 0:
|
||||
encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module,
|
||||
conv_layer_order=layer_order, num_groups=num_groups)
|
||||
encoder = Encoder(
|
||||
in_channels,
|
||||
out_feature_num,
|
||||
apply_pooling=False,
|
||||
basic_module=basic_module,
|
||||
conv_layer_order=layer_order,
|
||||
num_groups=num_groups,
|
||||
)
|
||||
else:
|
||||
# TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
|
||||
# currently pools with a constant kernel: (2, 2, 2)
|
||||
encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module,
|
||||
conv_layer_order=layer_order, num_groups=num_groups)
|
||||
encoder = Encoder(
|
||||
f_maps[i - 1],
|
||||
out_feature_num,
|
||||
basic_module=basic_module,
|
||||
conv_layer_order=layer_order,
|
||||
num_groups=num_groups,
|
||||
)
|
||||
encoders.append(encoder)
|
||||
|
||||
self.encoders = nn.ModuleList(encoders)
|
||||
|
@ -434,13 +504,18 @@ class Abstract3DUNet(nn.Module):
|
|||
out_feature_num = reversed_f_maps[i + 1]
|
||||
# TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
|
||||
# currently strides with a constant stride: (2, 2, 2)
|
||||
decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module,
|
||||
conv_layer_order=layer_order, num_groups=num_groups)
|
||||
decoder = Decoder(
|
||||
in_feature_num,
|
||||
out_feature_num,
|
||||
basic_module=basic_module,
|
||||
conv_layer_order=layer_order,
|
||||
num_groups=num_groups,
|
||||
)
|
||||
decoders.append(decoder)
|
||||
|
||||
self.decoders = nn.ModuleList(decoders)
|
||||
|
||||
# in the last layer a 1×1 convolution reduces the number of output
|
||||
# in the last layer a 1x1 convolution reduces the number of output
|
||||
# channels to the number of labels
|
||||
self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
|
||||
|
||||
|
@ -455,7 +530,6 @@ class Abstract3DUNet(nn.Module):
|
|||
self.final_activation = None
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.embed_fn is not None:
|
||||
x = self.embed_fn(x.permute(0, 2, 3, 4, 1))
|
||||
x = x.permute(0, 4, 1, 2, 3)
|
||||
|
@ -488,49 +562,81 @@ class Abstract3DUNet(nn.Module):
|
|||
|
||||
|
||||
class UNet3D(Abstract3DUNet):
|
||||
"""
|
||||
3DUnet model from
|
||||
"""3DUnet model from
|
||||
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
|
||||
<https://arxiv.org/pdf/1606.06650.pdf>`.
|
||||
|
||||
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
|
||||
num_groups=8, num_levels=4, is_segmentation=True, **kwargs):
|
||||
super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid,
|
||||
basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order,
|
||||
num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation,
|
||||
**kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
final_sigmoid=True,
|
||||
f_maps=64,
|
||||
layer_order="gcr",
|
||||
num_groups=8,
|
||||
num_levels=4,
|
||||
is_segmentation=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
final_sigmoid=final_sigmoid,
|
||||
basic_module=DoubleConv,
|
||||
f_maps=f_maps,
|
||||
layer_order=layer_order,
|
||||
num_groups=num_groups,
|
||||
num_levels=num_levels,
|
||||
is_segmentation=is_segmentation,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class ResidualUNet3D(Abstract3DUNet):
|
||||
"""
|
||||
Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
|
||||
"""Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
|
||||
Uses ExtResNetBlock as a basic building block, summation joining instead
|
||||
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
|
||||
Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
|
||||
num_groups=8, num_levels=5, is_segmentation=True, **kwargs):
|
||||
super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels,
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
final_sigmoid=True,
|
||||
f_maps=64,
|
||||
layer_order="gcr",
|
||||
num_groups=8,
|
||||
num_levels=5,
|
||||
is_segmentation=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
final_sigmoid=final_sigmoid,
|
||||
basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order,
|
||||
num_groups=num_groups, num_levels=num_levels,
|
||||
basic_module=ExtResNetBlock,
|
||||
f_maps=f_maps,
|
||||
layer_order=layer_order,
|
||||
num_groups=num_groups,
|
||||
num_levels=num_levels,
|
||||
is_segmentation=is_segmentation,
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def get_model(config):
|
||||
def _model_class(class_name):
|
||||
m = importlib.import_module('pytorch3dunet.unet3d.model')
|
||||
m = importlib.import_module("pytorch3dunet.unet3d.model")
|
||||
clazz = getattr(m, class_name)
|
||||
return clazz
|
||||
|
||||
assert 'model' in config, 'Could not find model configuration'
|
||||
model_config = config['model']
|
||||
model_class = _model_class(model_config['name'])
|
||||
assert "model" in config, "Could not find model configuration"
|
||||
model_config = config["model"]
|
||||
model_class = _model_class(model_config["name"])
|
||||
return model_class(**model_config)
|
||||
|
||||
|
||||
|
@ -542,18 +648,18 @@ if __name__ == "__main__":
|
|||
out_channels = 1
|
||||
f_maps = 32
|
||||
num_levels = 2
|
||||
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr')
|
||||
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order="cr")
|
||||
print(model)
|
||||
print('number of parameters: ', sum(p.numel() for p in model.parameters()))
|
||||
print("number of parameters: ", sum(p.numel() for p in model.parameters()))
|
||||
|
||||
reso = 18
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
x = np.zeros((1, 1, reso, reso, reso))
|
||||
x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan
|
||||
x[:, :, int(reso / 2 - 1), int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
|
||||
x = torch.FloatTensor(x)
|
||||
|
||||
out = model(x)
|
||||
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso)))
|
||||
|
||||
print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso * reso)))
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
|
||||
"""Positional encoding embedding. Code was taken from https://github.com/bmild/nerf."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Embedder:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
@ -10,24 +11,23 @@ class Embedder:
|
|||
|
||||
def create_embedding_fn(self):
|
||||
embed_fns = []
|
||||
d = self.kwargs['input_dims']
|
||||
d = self.kwargs["input_dims"]
|
||||
out_dim = 0
|
||||
if self.kwargs['include_input']:
|
||||
if self.kwargs["include_input"]:
|
||||
embed_fns.append(lambda x: x)
|
||||
out_dim += d
|
||||
|
||||
max_freq = self.kwargs['max_freq_log2']
|
||||
N_freqs = self.kwargs['num_freqs']
|
||||
max_freq = self.kwargs["max_freq_log2"]
|
||||
N_freqs = self.kwargs["num_freqs"]
|
||||
|
||||
if self.kwargs['log_sampling']:
|
||||
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
|
||||
if self.kwargs["log_sampling"]:
|
||||
freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs)
|
||||
else:
|
||||
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
|
||||
freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs)
|
||||
|
||||
for freq in freq_bands:
|
||||
for p_fn in self.kwargs['periodic_fns']:
|
||||
embed_fns.append(lambda x, p_fn=p_fn,
|
||||
freq=freq: p_fn(x * freq))
|
||||
for p_fn in self.kwargs["periodic_fns"]:
|
||||
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
|
||||
out_dim += d
|
||||
|
||||
self.embed_fns = embed_fns
|
||||
|
@ -36,31 +36,36 @@ class Embedder:
|
|||
def embed(self, inputs):
|
||||
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
||||
|
||||
|
||||
def get_embedder(multires, d_in=3):
|
||||
embed_kwargs = {
|
||||
'include_input': True,
|
||||
'input_dims': d_in,
|
||||
'max_freq_log2': multires-1,
|
||||
'num_freqs': multires,
|
||||
'log_sampling': True,
|
||||
'periodic_fns': [torch.sin, torch.cos],
|
||||
"include_input": True,
|
||||
"input_dims": d_in,
|
||||
"max_freq_log2": multires - 1,
|
||||
"num_freqs": multires,
|
||||
"log_sampling": True,
|
||||
"periodic_fns": [torch.sin, torch.cos],
|
||||
}
|
||||
|
||||
embedder_obj = Embedder(**embed_kwargs)
|
||||
def embed(x, eo=embedder_obj): return eo.embed(x)
|
||||
|
||||
def embed(x, eo=embedder_obj):
|
||||
return eo.embed(x)
|
||||
|
||||
return embed, embedder_obj.out_dim
|
||||
|
||||
def normalize_coordinate(p, plane='xz'):
|
||||
''' Normalize coordinate to [0, 1] for unit cube experiments
|
||||
|
||||
def normalize_coordinate(p, plane="xz"):
|
||||
"""Normalize coordinate to [0, 1] for unit cube experiments.
|
||||
|
||||
Args:
|
||||
p (tensor): point
|
||||
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
||||
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
||||
'''
|
||||
if plane == 'xz':
|
||||
"""
|
||||
if plane == "xz":
|
||||
xy = p[:, :, [0, 2]]
|
||||
elif plane =='xy':
|
||||
elif plane == "xy":
|
||||
xy = p[:, :, [0, 1]]
|
||||
else:
|
||||
xy = p[:, :, [1, 2]]
|
||||
|
@ -75,40 +80,41 @@ def normalize_coordinate(p, plane='xz'):
|
|||
|
||||
|
||||
def normalize_3d_coordinate(p):
|
||||
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
||||
'''
|
||||
"""Normalize coordinate to [0, 1] for unit cube experiments."""
|
||||
if p.max() >= 1:
|
||||
p[p >= 1] = 1 - 10e-6
|
||||
if p.min() < 0:
|
||||
p[p < 0] = 0.0
|
||||
return p
|
||||
|
||||
def coordinate2index(x, reso, coord_type='2d'):
|
||||
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
||||
Corresponds to our 3D model
|
||||
|
||||
def coordinate2index(x, reso, coord_type="2d"):
|
||||
"""Normalize coordinate to [0, 1] for unit cube experiments.
|
||||
Corresponds to our 3D model.
|
||||
|
||||
Args:
|
||||
x (tensor): coordinate
|
||||
reso (int): defined resolution
|
||||
coord_type (str): coordinate type
|
||||
'''
|
||||
"""
|
||||
x = (x * reso).long()
|
||||
if coord_type == '2d': # plane
|
||||
if coord_type == "2d": # plane
|
||||
index = x[:, :, 0] + reso * x[:, :, 1]
|
||||
elif coord_type == '3d': # grid
|
||||
elif coord_type == "3d": # grid
|
||||
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
|
||||
index = index[:, None, :]
|
||||
return index
|
||||
|
||||
|
||||
class map2local(object):
|
||||
''' Add new keys to the given input
|
||||
class map2local:
|
||||
"""Add new keys to the given input.
|
||||
|
||||
Args:
|
||||
s (float): the defined voxel size
|
||||
pos_encoding (str): method for the positional encoding, linear|sin_cos
|
||||
'''
|
||||
def __init__(self, s, pos_encoding='linear'):
|
||||
"""
|
||||
|
||||
def __init__(self, s, pos_encoding="linear"):
|
||||
super().__init__()
|
||||
self.s = s
|
||||
# self.pe = positional_encoding(basis_function=pos_encoding, local=True)
|
||||
|
@ -121,15 +127,16 @@ class map2local(object):
|
|||
# p = self.pe(p)
|
||||
return p
|
||||
|
||||
|
||||
# Resnet Blocks
|
||||
class ResnetBlockFC(nn.Module):
|
||||
''' Fully connected ResNet Block class.
|
||||
"""Fully connected ResNet Block class.
|
||||
|
||||
Args:
|
||||
size_in (int): input dimension
|
||||
size_out (int): output dimension
|
||||
size_h (int): hidden dimension
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, size_in, size_out=None, size_h=None, siren=False):
|
||||
super().__init__()
|
||||
|
|
|
@ -1,54 +1,56 @@
|
|||
import time, os
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import open3d as o3d
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import trimesh
|
||||
from pytorch3d.loss import chamfer_distance
|
||||
from torchvision.io import write_video
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from src.dpsr import DPSR
|
||||
from src.model import PSR2Mesh
|
||||
from src.utils import grid_interp, verts_on_largest_mesh,\
|
||||
export_pointcloud, mc_from_psr, GaussianSmoothing
|
||||
from src.visualize import visualize_points_mesh, visualize_psr_grid, \
|
||||
visualize_mesh_phong, render_rgb
|
||||
from torchvision.utils import save_image
|
||||
from torchvision.io import write_video
|
||||
from pytorch3d.loss import chamfer_distance
|
||||
import open3d as o3d
|
||||
from src.utils import export_pointcloud, mc_from_psr, verts_on_largest_mesh
|
||||
from src.visualize import (
|
||||
render_rgb,
|
||||
visualize_mesh_phong,
|
||||
visualize_points_mesh,
|
||||
visualize_psr_grid,
|
||||
)
|
||||
|
||||
class Trainer(object):
|
||||
'''
|
||||
Args:
|
||||
|
||||
class Trainer:
|
||||
"""Args:
|
||||
cfg : config file
|
||||
optimizer : pytorch optimizer object
|
||||
device : pytorch device
|
||||
'''
|
||||
device : pytorch device.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, optimizer, device=None):
|
||||
self.optimizer = optimizer
|
||||
self.device = device
|
||||
self.cfg = cfg
|
||||
self.psr2mesh = PSR2Mesh.apply
|
||||
self.data_type = cfg['data']['data_type']
|
||||
self.data_type = cfg["data"]["data_type"]
|
||||
|
||||
# initialize DPSR
|
||||
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
||||
cfg['model']['grid_res'],
|
||||
cfg['model']['grid_res']),
|
||||
sig=cfg['model']['psr_sigma'])
|
||||
self.dpsr = DPSR(
|
||||
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
|
||||
sig=cfg["model"]["psr_sigma"],
|
||||
)
|
||||
if torch.cuda.device_count() > 1:
|
||||
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR
|
||||
self.dpsr = self.dpsr.to(device)
|
||||
|
||||
def train_step(self, data, inputs, model, it):
|
||||
''' Performs a training step.
|
||||
"""Performs a training step.
|
||||
|
||||
Args:
|
||||
data (dict) : data dictionary
|
||||
inputs (torch.tensor) : input point clouds
|
||||
model (nn.Module or None): a neural network or None
|
||||
it (int) : the number of iterations
|
||||
'''
|
||||
|
||||
"""
|
||||
self.optimizer.zero_grad()
|
||||
loss, loss_each = self.compute_loss(inputs, data, model, it)
|
||||
|
||||
|
@ -58,16 +60,15 @@ class Trainer(object):
|
|||
return loss.item(), loss_each
|
||||
|
||||
def compute_loss(self, inputs, data, model, it=0):
|
||||
''' Compute the loss.
|
||||
"""Compute the loss.
|
||||
|
||||
Args:
|
||||
data (dict) : data dictionary
|
||||
inputs (torch.tensor) : input point clouds
|
||||
model (nn.Module or None): a neural network or None
|
||||
it (int) : the number of iterations
|
||||
'''
|
||||
|
||||
device = self.device
|
||||
res = self.cfg['model']['grid_res']
|
||||
it (int) : the number of iterations.
|
||||
"""
|
||||
res = self.cfg["model"]["grid_res"]
|
||||
|
||||
# source oriented point clouds to PSR grid
|
||||
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||
|
@ -77,36 +78,33 @@ class Trainer(object):
|
|||
|
||||
# the output is in the range of [0, 1), we make it to the real range [0, 1].
|
||||
# This is a hack for our DPSR solver
|
||||
v = v * res / (res-1)
|
||||
v = v * res / (res - 1)
|
||||
|
||||
points = points * 2. - 1.
|
||||
v = v * 2. - 1. # within the range of (-1, 1)
|
||||
points = points * 2.0 - 1.0
|
||||
v = v * 2.0 - 1.0 # within the range of (-1, 1)
|
||||
|
||||
loss = 0
|
||||
loss_each = {}
|
||||
# compute loss
|
||||
if self.data_type == 'point':
|
||||
if self.cfg['train']['w_chamfer'] > 0:
|
||||
loss_ = self.cfg['train']['w_chamfer'] * \
|
||||
self.compute_3d_loss(v, data)
|
||||
loss_each['chamfer'] = loss_
|
||||
if self.data_type == "point":
|
||||
if self.cfg["train"]["w_chamfer"] > 0:
|
||||
loss_ = self.cfg["train"]["w_chamfer"] * self.compute_3d_loss(v, data)
|
||||
loss_each["chamfer"] = loss_
|
||||
loss += loss_
|
||||
elif self.data_type == 'img':
|
||||
elif self.data_type == "img":
|
||||
loss, loss_each = self.compute_2d_loss(inputs, data, model)
|
||||
|
||||
return loss, loss_each
|
||||
|
||||
|
||||
def pcl2psr(self, inputs):
|
||||
''' Convert an oriented point cloud to PSR indicator grid
|
||||
"""Convert an oriented point cloud to PSR indicator grid
|
||||
Args:
|
||||
inputs (torch.tensor): input oriented point clouds
|
||||
'''
|
||||
|
||||
points, normals = inputs[...,:3], inputs[...,3:]
|
||||
if self.cfg['model']['apply_sigmoid']:
|
||||
inputs (torch.tensor): input oriented point clouds.
|
||||
"""
|
||||
points, normals = inputs[..., :3], inputs[..., 3:]
|
||||
if self.cfg["model"]["apply_sigmoid"]:
|
||||
points = torch.sigmoid(points)
|
||||
if self.cfg['model']['normal_normalize']:
|
||||
if self.cfg["model"]["normal_normalize"]:
|
||||
normals = normals / normals.norm(dim=-1, keepdim=True)
|
||||
|
||||
# DPSR to get grid
|
||||
|
@ -116,17 +114,17 @@ class Trainer(object):
|
|||
return psr_grid, points, normals
|
||||
|
||||
def compute_3d_loss(self, v, data):
|
||||
''' Compute the loss for point clouds.
|
||||
"""Compute the loss for point clouds.
|
||||
|
||||
Args:
|
||||
v (torch.tensor) : mesh vertices
|
||||
data (dict) : data dictionary
|
||||
'''
|
||||
|
||||
pts_gt = data.get('target_points')
|
||||
idx = np.random.randint(pts_gt.shape[1], size=self.cfg['train']['n_sup_point'])
|
||||
if self.cfg['train']['subsample_vertex']:
|
||||
#chamfer distance only on random sampled vertices
|
||||
idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point'])
|
||||
data (dict) : data dictionary.
|
||||
"""
|
||||
pts_gt = data.get("target_points")
|
||||
idx = np.random.randint(pts_gt.shape[1], size=self.cfg["train"]["n_sup_point"])
|
||||
if self.cfg["train"]["subsample_vertex"]:
|
||||
# chamfer distance only on random sampled vertices
|
||||
idx = np.random.randint(v.shape[1], size=self.cfg["train"]["n_sup_point"])
|
||||
loss, _ = chamfer_distance(v[:, idx], pts_gt)
|
||||
else:
|
||||
loss, _ = chamfer_distance(v, pts_gt)
|
||||
|
@ -134,34 +132,31 @@ class Trainer(object):
|
|||
return loss
|
||||
|
||||
def compute_2d_loss(self, inputs, data, model):
|
||||
''' Compute the 2D losses.
|
||||
"""Compute the 2D losses.
|
||||
|
||||
Args:
|
||||
inputs (torch.tensor) : input source point clouds
|
||||
data (dict) : data dictionary
|
||||
model (nn.Module or None): neural network or None
|
||||
'''
|
||||
|
||||
losses = {"color":
|
||||
{"weight": self.cfg['train']['l_weight']['rgb'],
|
||||
"values": []
|
||||
model (nn.Module or None): neural network or None.
|
||||
"""
|
||||
losses = {
|
||||
"color": {
|
||||
"weight": self.cfg["train"]["l_weight"]["rgb"],
|
||||
"values": [],
|
||||
},
|
||||
"silhouette":
|
||||
{"weight": self.cfg['train']['l_weight']['mask'],
|
||||
"values": []},
|
||||
"silhouette": {"weight": self.cfg["train"]["l_weight"]["mask"], "values": []},
|
||||
}
|
||||
loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses}
|
||||
|
||||
# forward pass
|
||||
out = model(inputs, data)
|
||||
|
||||
if out['rgb'] is not None:
|
||||
rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'],
|
||||
-1, 3)[out['vis_mask']]
|
||||
loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt,
|
||||
out['rgb']) / out['rgb'].shape[0]
|
||||
if out["rgb"] is not None:
|
||||
rgb_gt = out["rgb_gt"].reshape(self.cfg["data"]["n_views_per_iter"], -1, 3)[out["vis_mask"]]
|
||||
loss_all["color"] += torch.nn.L1Loss(reduction="sum")(rgb_gt, out["rgb"]) / out["rgb"].shape[0]
|
||||
|
||||
if out['mask'] is not None:
|
||||
loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean()
|
||||
if out["mask"] is not None:
|
||||
loss_all["silhouette"] += ((out["mask"] - out["mask_gt"]) ** 2).mean()
|
||||
|
||||
# weighted sum of the losses
|
||||
loss = torch.tensor(0.0, device=self.device)
|
||||
|
@ -172,15 +167,14 @@ class Trainer(object):
|
|||
return loss, loss_all
|
||||
|
||||
def point_resampling(self, inputs):
|
||||
''' Resample points
|
||||
"""Resample points
|
||||
Args:
|
||||
inputs (torch.tensor): oriented point clouds
|
||||
'''
|
||||
|
||||
inputs (torch.tensor): oriented point clouds.
|
||||
"""
|
||||
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||
|
||||
# shortcuts
|
||||
n_grow = self.cfg['train']['n_grow_points']
|
||||
n_grow = self.cfg["train"]["n_grow_points"]
|
||||
|
||||
# [hack] for points resampled from the mesh from marching cubes,
|
||||
# we need to divide by s instead of (s-1), and the scale is correct.
|
||||
|
@ -191,13 +185,13 @@ class Trainer(object):
|
|||
|
||||
# sample vertices only from the largest component, not from fragments
|
||||
mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh)
|
||||
pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True)
|
||||
normals_i = mesh.face_normals[face_idx].astype('float32')
|
||||
pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None]
|
||||
pi, face_idx = mesh.sample(n_grow + points.shape[1], return_index=True)
|
||||
normals_i = mesh.face_normals[face_idx].astype("float32")
|
||||
pts_mesh = torch.tensor(pi.astype("float32")).to(self.device)[None]
|
||||
n_mesh = torch.tensor(normals_i).to(self.device)[None]
|
||||
|
||||
points, normals = pts_mesh, n_mesh
|
||||
print('{} total points are resampled'.format(points.shape[1]))
|
||||
print(f"{points.shape[1]} total points are resampled")
|
||||
|
||||
# update inputs
|
||||
points = torch.log(points / (1 - points)) # inverse sigmoid
|
||||
|
@ -207,119 +201,112 @@ class Trainer(object):
|
|||
return inputs
|
||||
|
||||
def visualize(self, data, inputs, renderer, epoch, o3d_vis=None):
|
||||
''' Visualization.
|
||||
"""Visualization.
|
||||
|
||||
Args:
|
||||
data (dict) : data dictionary
|
||||
inputs (torch.tensor) : source point clouds
|
||||
renderer (nn.Module or None): a neural network or None
|
||||
epoch (int) : the number of iterations
|
||||
o3d_vis (o3d.Visualizer) : open3d visualizer
|
||||
'''
|
||||
o3d_vis (o3d.Visualizer) : open3d visualizer.
|
||||
"""
|
||||
data_type = self.cfg["data"]["data_type"]
|
||||
it = "{:04d}".format(int(epoch / self.cfg["train"]["visualize_every"]))
|
||||
|
||||
data_type = self.cfg['data']['data_type']
|
||||
it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every']))
|
||||
|
||||
|
||||
if (self.cfg['train']['exp_mesh']) \
|
||||
| (self.cfg['train']['exp_pcl']) \
|
||||
| (self.cfg['train']['o3d_show']):
|
||||
if (self.cfg["train"]["exp_mesh"]) | (self.cfg["train"]["exp_pcl"]) | (self.cfg["train"]["o3d_show"]):
|
||||
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||
|
||||
with torch.no_grad():
|
||||
v, f, n = mc_from_psr(psr_grid, pytorchify=True,
|
||||
zero_level=self.cfg['data']['zero_level'], real_scale=True)
|
||||
v, f, n = mc_from_psr(
|
||||
psr_grid, pytorchify=True, zero_level=self.cfg["data"]["zero_level"], real_scale=True,
|
||||
)
|
||||
v, f, n = v[None], f[None], n[None]
|
||||
|
||||
v = v * 2. - 1. # change to the range of [-1, 1]
|
||||
v = v * 2.0 - 1.0 # change to the range of [-1, 1]
|
||||
|
||||
color_v = None
|
||||
if data_type == 'img':
|
||||
if self.cfg['train']['vis_vert_color'] & \
|
||||
(self.cfg['train']['l_weight']['rgb'] != 0.):
|
||||
color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy()
|
||||
color_v[color_v<0], color_v[color_v>1] = 0., 1.
|
||||
if data_type == "img":
|
||||
if self.cfg["train"]["vis_vert_color"] & (self.cfg["train"]["l_weight"]["rgb"] != 0.0):
|
||||
color_v = renderer["color"](v, n).squeeze().detach().cpu().numpy()
|
||||
color_v[color_v < 0], color_v[color_v > 1] = 0.0, 1.0
|
||||
|
||||
vv = v.detach().squeeze().cpu().numpy()
|
||||
ff = f.detach().squeeze().cpu().numpy()
|
||||
points = points * 2 - 1
|
||||
visualize_points_mesh(o3d_vis, points, normals,
|
||||
vv, ff, self.cfg, it, epoch, color_v=color_v)
|
||||
visualize_points_mesh(o3d_vis, points, normals, vv, ff, self.cfg, it, epoch, color_v=color_v)
|
||||
|
||||
else:
|
||||
v, f, n = inputs
|
||||
|
||||
|
||||
if (data_type == 'img') & (self.cfg['train']['vis_rendering']):
|
||||
if (data_type == "img") & (self.cfg["train"]["vis_rendering"]):
|
||||
pred_imgs = []
|
||||
pred_masks = []
|
||||
n_views = len(data['poses'])
|
||||
len(data["poses"])
|
||||
# idx_list = trange(n_views)
|
||||
idx_list = [13, 24, 27, 48]
|
||||
|
||||
#!
|
||||
model = renderer.eval()
|
||||
for idx in idx_list:
|
||||
pose = data['poses'][idx]
|
||||
rgb = data['rgbs'][idx]
|
||||
mask_gt = data['masks'][idx]
|
||||
img_size = rgb.shape[0] if rgb.shape[0]== rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
|
||||
pose = data["poses"][idx]
|
||||
rgb = data["rgbs"][idx]
|
||||
data["masks"][idx]
|
||||
img_size = rgb.shape[0] if rgb.shape[0] == rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
|
||||
ray = None
|
||||
if 'rays' in data.keys():
|
||||
ray = data['rays'][idx]
|
||||
if self.cfg['train']['l_weight']['rgb'] != 0.:
|
||||
if "rays" in data.keys():
|
||||
ray = data["rays"][idx]
|
||||
if self.cfg["train"]["l_weight"]["rgb"] != 0.0:
|
||||
fea_grid = None
|
||||
if model.unet3d is not None:
|
||||
with torch.no_grad():
|
||||
fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1)
|
||||
if model.encoder is not None:
|
||||
pp = torch.cat([(points+1)/2, normals], dim=-1)
|
||||
fea_grid = model.encoder(pp,
|
||||
normalize=False).permute(0, 2, 3, 4, 1)
|
||||
pp = torch.cat([(points + 1) / 2, normals], dim=-1)
|
||||
fea_grid = model.encoder(pp, normalize=False).permute(0, 2, 3, 4, 1)
|
||||
|
||||
pred, visible_mask = render_rgb(v, f, n, pose,
|
||||
model.rendering_network.eval(),
|
||||
img_size, ray=ray, fea_grid=fea_grid)
|
||||
img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3])
|
||||
pred, visible_mask = render_rgb(
|
||||
v, f, n, pose, model.rendering_network.eval(), img_size, ray=ray, fea_grid=fea_grid,
|
||||
)
|
||||
img_pred = torch.zeros([rgb.shape[0] * rgb.shape[1], 3])
|
||||
img_pred[visible_mask] = pred.detach().cpu()
|
||||
|
||||
img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3)
|
||||
img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1.
|
||||
filename=os.path.join(self.cfg['train']['dir_rendering'],
|
||||
'rendering_{}_{:d}.png'.format(it, idx))
|
||||
img_pred[img_pred < 0], img_pred[img_pred > 1] = 0.0, 1.0
|
||||
filename = os.path.join(self.cfg["train"]["dir_rendering"], f"rendering_{it}_{idx:d}.png")
|
||||
save_image(img_pred.permute(2, 0, 1), filename)
|
||||
pred_imgs.append(img_pred)
|
||||
|
||||
#! Mesh rendering using Phong shading model
|
||||
filename=os.path.join(self.cfg['train']['dir_rendering'],
|
||||
'mesh_{}_{:d}.png'.format(it, idx))
|
||||
filename = os.path.join(self.cfg["train"]["dir_rendering"], f"mesh_{it}_{idx:d}.png")
|
||||
visualize_mesh_phong(v, f, n, pose, img_size, name=filename)
|
||||
|
||||
if len(pred_imgs) >= 1:
|
||||
pred_imgs = torch.stack(pred_imgs, dim=0)
|
||||
save_image(pred_imgs.permute(0, 3, 1, 2),
|
||||
os.path.join(self.cfg['train']['dir_rendering'],
|
||||
'{}.png'.format(it)), nrow=4)
|
||||
if self.cfg['train']['save_video']:
|
||||
write_video(os.path.join(self.cfg['train']['dir_rendering'],
|
||||
'{}.mp4'.format(it)),
|
||||
(pred_imgs*255.).type(torch.uint8), fps=24)
|
||||
save_image(
|
||||
pred_imgs.permute(0, 3, 1, 2), os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.png"), nrow=4,
|
||||
)
|
||||
if self.cfg["train"]["save_video"]:
|
||||
write_video(
|
||||
os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.mp4"),
|
||||
(pred_imgs * 255.0).type(torch.uint8),
|
||||
fps=24,
|
||||
)
|
||||
|
||||
def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None):
|
||||
''' Save meshes and point clouds.
|
||||
"""Save meshes and point clouds.
|
||||
|
||||
Args:
|
||||
inputs (torch.tensor) : source point clouds
|
||||
epoch (int) : the number of iterations
|
||||
center (numpy.array) : center of the shape
|
||||
scale (numpy.array) : scale of the shape
|
||||
'''
|
||||
|
||||
exp_pcl = self.cfg['train']['exp_pcl']
|
||||
exp_mesh = self.cfg['train']['exp_mesh']
|
||||
scale (numpy.array) : scale of the shape.
|
||||
"""
|
||||
exp_pcl = self.cfg["train"]["exp_pcl"]
|
||||
exp_mesh = self.cfg["train"]["exp_mesh"]
|
||||
|
||||
psr_grid, points, normals = self.pcl2psr(inputs)
|
||||
|
||||
if exp_pcl:
|
||||
dir_pcl = self.cfg['train']['dir_pcl']
|
||||
dir_pcl = self.cfg["train"]["dir_pcl"]
|
||||
p = points.squeeze(0).detach().cpu().numpy()
|
||||
p = p * 2 - 1
|
||||
n = normals.squeeze(0).detach().cpu().numpy()
|
||||
|
@ -327,12 +314,11 @@ class Trainer(object):
|
|||
p *= scale
|
||||
if center is not None:
|
||||
p += center
|
||||
export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n)
|
||||
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}.ply"), p, n)
|
||||
if exp_mesh:
|
||||
dir_mesh = self.cfg['train']['dir_mesh']
|
||||
dir_mesh = self.cfg["train"]["dir_mesh"]
|
||||
with torch.no_grad():
|
||||
v, f, _ = mc_from_psr(psr_grid,
|
||||
zero_level=self.cfg['data']['zero_level'], real_scale=True)
|
||||
v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"], real_scale=True)
|
||||
v = v * 2 - 1
|
||||
if scale is not None:
|
||||
v *= scale
|
||||
|
@ -341,9 +327,9 @@ class Trainer(object):
|
|||
mesh = o3d.geometry.TriangleMesh()
|
||||
mesh.vertices = o3d.utility.Vector3dVector(v)
|
||||
mesh.triangles = o3d.utility.Vector3iVector(f)
|
||||
outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch))
|
||||
outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}.ply")
|
||||
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
|
||||
|
||||
if self.cfg['train']['vis_psr']:
|
||||
dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all'
|
||||
if self.cfg["train"]["vis_psr"]:
|
||||
dir_psr_vis = self.cfg["train"]["out_dir"] + "/psr_vis_all"
|
||||
visualize_psr_grid(psr_grid, out_dir=dir_psr_vis)
|
||||
|
|
153
src/training.py
153
src/training.py
|
@ -1,55 +1,61 @@
|
|||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.loss import chamfer_distance
|
||||
from pytorch3d.ops.knn import knn_gather, knn_points
|
||||
from torch.nn import functional as F
|
||||
from collections import defaultdict
|
||||
import trimesh
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.dpsr import DPSR
|
||||
from src.utils import grid_interp, export_pointcloud, export_mesh, \
|
||||
mc_from_psr, scale2onet, GaussianSmoothing
|
||||
from pytorch3d.ops.knn import knn_gather, knn_points
|
||||
from pytorch3d.loss import chamfer_distance
|
||||
from pdb import set_trace as st
|
||||
from src.utils import (
|
||||
GaussianSmoothing,
|
||||
export_mesh,
|
||||
export_pointcloud,
|
||||
mc_from_psr,
|
||||
scale2onet,
|
||||
)
|
||||
|
||||
class Trainer(object):
|
||||
'''
|
||||
Args:
|
||||
|
||||
class Trainer:
|
||||
"""Args:
|
||||
model (nn.Module): our defined model
|
||||
optimizer (optimizer): pytorch optimizer object
|
||||
device (device): pytorch device
|
||||
input_type (str): input type
|
||||
vis_dir (str): visualization directory
|
||||
'''
|
||||
vis_dir (str): visualization directory.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, optimizer, device=None):
|
||||
self.optimizer = optimizer
|
||||
self.device = device
|
||||
self.cfg = cfg
|
||||
if self.cfg['train']['w_raw'] != 0:
|
||||
if self.cfg["train"]["w_raw"] != 0:
|
||||
from src.model import PSR2Mesh
|
||||
|
||||
self.psr2mesh = PSR2Mesh.apply
|
||||
|
||||
# initialize DPSR
|
||||
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
|
||||
cfg['model']['grid_res'],
|
||||
cfg['model']['grid_res']),
|
||||
sig=cfg['model']['psr_sigma'])
|
||||
self.dpsr = DPSR(
|
||||
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
|
||||
sig=cfg["model"]["psr_sigma"],
|
||||
)
|
||||
if torch.cuda.device_count() > 1:
|
||||
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR
|
||||
self.dpsr = self.dpsr.to(device)
|
||||
|
||||
if cfg['train']['gauss_weight']>0.:
|
||||
if cfg["train"]["gauss_weight"] > 0.0:
|
||||
self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device)
|
||||
|
||||
def train_step(self, inputs, data, model):
|
||||
''' Performs a training step.
|
||||
"""Performs a training step.
|
||||
|
||||
Args:
|
||||
data (dict): data dictionary
|
||||
'''
|
||||
"""
|
||||
self.optimizer.zero_grad()
|
||||
p = data.get('inputs').to(self.device)
|
||||
p = data.get("inputs").to(self.device)
|
||||
|
||||
out = model(p)
|
||||
|
||||
|
@ -57,18 +63,18 @@ class Trainer(object):
|
|||
|
||||
loss = 0
|
||||
loss_each = {}
|
||||
if self.cfg['train']['w_psr'] != 0:
|
||||
psr_gt = data.get('gt_psr').to(self.device)
|
||||
if self.cfg['model']['psr_tanh']:
|
||||
if self.cfg["train"]["w_psr"] != 0:
|
||||
psr_gt = data.get("gt_psr").to(self.device)
|
||||
if self.cfg["model"]["psr_tanh"]:
|
||||
psr_gt = torch.tanh(psr_gt)
|
||||
|
||||
psr_grid = self.dpsr(points, normals)
|
||||
if self.cfg['model']['psr_tanh']:
|
||||
if self.cfg["model"]["psr_tanh"]:
|
||||
psr_grid = torch.tanh(psr_grid)
|
||||
|
||||
# apply a rescaling weight based on GT SDF values
|
||||
if self.cfg['train']['gauss_weight']>0:
|
||||
gauss_sigma = self.cfg['train']['gauss_weight']
|
||||
if self.cfg["train"]["gauss_weight"] > 0:
|
||||
self.cfg["train"]["gauss_weight"]
|
||||
# set up the weighting for loss, higher weights
|
||||
# for points near to the surface
|
||||
psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1)
|
||||
|
@ -82,43 +88,43 @@ class Trainer(object):
|
|||
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1)
|
||||
psr_grad_norm = psr_grad.norm(dim=-1)[:, None]
|
||||
w = torch.nn.ReplicationPad3d(3)(psr_grad_norm)
|
||||
w = 2*self.gauss_smooth(w).squeeze(1)
|
||||
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt)
|
||||
w = 2 * self.gauss_smooth(w).squeeze(1)
|
||||
loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(w * psr_grid, w * psr_gt)
|
||||
else:
|
||||
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt)
|
||||
loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(psr_grid, psr_gt)
|
||||
|
||||
loss += loss_each['psr']
|
||||
loss += loss_each["psr"]
|
||||
|
||||
# regularization on the input point positions via chamfer distance
|
||||
if self.cfg['train']['w_reg_point'] != 0.:
|
||||
points_gt = data.get('gt_points').to(self.device)
|
||||
if self.cfg["train"]["w_reg_point"] != 0.0:
|
||||
points_gt = data.get("gt_points").to(self.device)
|
||||
loss_reg, loss_norm = chamfer_distance(points, points_gt)
|
||||
|
||||
loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg
|
||||
loss += loss_each['reg']
|
||||
loss_each["reg"] = self.cfg["train"]["w_reg_point"] * loss_reg
|
||||
loss += loss_each["reg"]
|
||||
|
||||
if self.cfg['train']['w_normals'] != 0.:
|
||||
points_gt = data.get('gt_points').to(self.device)
|
||||
normals_gt = data.get('gt_points.normals').to(self.device)
|
||||
if self.cfg["train"]["w_normals"] != 0.0:
|
||||
points_gt = data.get("gt_points").to(self.device)
|
||||
normals_gt = data.get("gt_points.normals").to(self.device)
|
||||
x_nn = knn_points(points, points_gt, K=1)
|
||||
x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :]
|
||||
|
||||
cham_norm_x = F.l1_loss(normals, x_normals_near)
|
||||
loss_norm = cham_norm_x
|
||||
|
||||
loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm
|
||||
loss += loss_each['normals']
|
||||
loss_each["normals"] = self.cfg["train"]["w_normals"] * loss_norm
|
||||
loss += loss_each["normals"]
|
||||
|
||||
if self.cfg['train']['w_raw'] != 0:
|
||||
res = self.cfg['model']['grid_res']
|
||||
if self.cfg["train"]["w_raw"] != 0:
|
||||
self.cfg["model"]["grid_res"]
|
||||
# DPSR to get grid
|
||||
psr_grid = self.dpsr(points, normals)
|
||||
if self.cfg['model']['psr_tanh']:
|
||||
if self.cfg["model"]["psr_tanh"]:
|
||||
psr_grid = torch.tanh(psr_grid)
|
||||
|
||||
v, f, n = self.psr2mesh(psr_grid)
|
||||
|
||||
pts_gt = data.get('gt_points').to(self.device)
|
||||
pts_gt = data.get("gt_points").to(self.device)
|
||||
|
||||
loss, _ = chamfer_distance(v, pts_gt)
|
||||
|
||||
|
@ -128,51 +134,51 @@ class Trainer(object):
|
|||
return loss.item(), loss_each
|
||||
|
||||
def save(self, model, data, epoch, id):
|
||||
p = data.get("inputs").to(self.device)
|
||||
|
||||
p = data.get('inputs').to(self.device)
|
||||
|
||||
exp_pcl = self.cfg['train']['exp_pcl']
|
||||
exp_mesh = self.cfg['train']['exp_mesh']
|
||||
exp_gt = self.cfg['generation']['exp_gt']
|
||||
exp_input = self.cfg['generation']['exp_input']
|
||||
exp_pcl = self.cfg["train"]["exp_pcl"]
|
||||
exp_mesh = self.cfg["train"]["exp_mesh"]
|
||||
exp_gt = self.cfg["generation"]["exp_gt"]
|
||||
exp_input = self.cfg["generation"]["exp_input"]
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
points, normals = model(p)
|
||||
|
||||
if exp_gt:
|
||||
points_gt = data.get('gt_points').to(self.device)
|
||||
normals_gt = data.get('gt_points.normals').to(self.device)
|
||||
points_gt = data.get("gt_points").to(self.device)
|
||||
normals_gt = data.get("gt_points.normals").to(self.device)
|
||||
|
||||
if exp_pcl:
|
||||
dir_pcl = self.cfg['train']['dir_pcl']
|
||||
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals)
|
||||
dir_pcl = self.cfg["train"]["dir_pcl"]
|
||||
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}.ply"), scale2onet(points), normals)
|
||||
if exp_gt:
|
||||
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt)
|
||||
export_pointcloud(
|
||||
os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(points_gt), normals_gt,
|
||||
)
|
||||
if exp_input:
|
||||
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p))
|
||||
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_input.ply"), scale2onet(p))
|
||||
|
||||
if exp_mesh:
|
||||
dir_mesh = self.cfg['train']['dir_mesh']
|
||||
dir_mesh = self.cfg["train"]["dir_mesh"]
|
||||
psr_grid = self.dpsr(points, normals)
|
||||
# psr_grid = torch.tanh(psr_grid)
|
||||
with torch.no_grad():
|
||||
v, f, _ = mc_from_psr(psr_grid,
|
||||
zero_level=self.cfg['data']['zero_level'])
|
||||
outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id))
|
||||
v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"])
|
||||
outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}.ply")
|
||||
export_mesh(outdir_mesh, scale2onet(v), f)
|
||||
if exp_gt:
|
||||
psr_gt = self.dpsr(points_gt, normals_gt)
|
||||
with torch.no_grad():
|
||||
v, f, _ = mc_from_psr(psr_gt,
|
||||
zero_level=self.cfg['data']['zero_level'])
|
||||
export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f)
|
||||
v, f, _ = mc_from_psr(psr_gt, zero_level=self.cfg["data"]["zero_level"])
|
||||
export_mesh(os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(v), f)
|
||||
|
||||
def evaluate(self, val_loader, model):
|
||||
''' Performs an evaluation.
|
||||
"""Performs an evaluation.
|
||||
|
||||
Args:
|
||||
val_loader (dataloader): pytorch dataloader
|
||||
'''
|
||||
val_loader (dataloader): pytorch dataloader.
|
||||
"""
|
||||
eval_list = defaultdict(list)
|
||||
|
||||
for data in tqdm(val_loader):
|
||||
|
@ -185,15 +191,16 @@ class Trainer(object):
|
|||
return eval_dict
|
||||
|
||||
def eval_step(self, data, model):
|
||||
''' Performs an evaluation step.
|
||||
"""Performs an evaluation step.
|
||||
|
||||
Args:
|
||||
data (dict): data dictionary
|
||||
'''
|
||||
data (dict): data dictionary.
|
||||
"""
|
||||
model.eval()
|
||||
eval_dict = {}
|
||||
|
||||
p = data.get('inputs').to(self.device)
|
||||
psr_gt = data.get('gt_psr').to(self.device)
|
||||
p = data.get("inputs").to(self.device)
|
||||
psr_gt = data.get("gt_psr").to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
# forward pass
|
||||
|
@ -201,7 +208,7 @@ class Trainer(object):
|
|||
# DPSR to get predicted psr grid
|
||||
psr_grid = self.dpsr(points, normals)
|
||||
|
||||
eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item()
|
||||
eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item()
|
||||
eval_dict["psr_l1"] = F.l1_loss(psr_grid, psr_gt).item()
|
||||
eval_dict["psr_l2"] = F.mse_loss(psr_grid, psr_gt).item()
|
||||
|
||||
return eval_dict
|
335
src/utils.py
335
src/utils.py
|
@ -1,53 +1,53 @@
|
|||
import torch
|
||||
import io, os, logging, urllib
|
||||
import yaml
|
||||
import trimesh
|
||||
import imageio
|
||||
import numbers
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
import numbers
|
||||
import os
|
||||
import urllib
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import open3d as o3d
|
||||
import torch
|
||||
import trimesh
|
||||
import yaml
|
||||
from igl import adjacency_matrix, connected_components
|
||||
from plyfile import PlyData
|
||||
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
|
||||
from pytorch3d.structures import Meshes
|
||||
from skimage import measure
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.utils import model_zoo
|
||||
from skimage import measure, img_as_float32
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
|
||||
from igl import adjacency_matrix, connected_components
|
||||
import open3d as o3d
|
||||
|
||||
##################################################
|
||||
# Below are functions for DPSR
|
||||
|
||||
|
||||
def fftfreqs(res, dtype=torch.float32, exact=True):
|
||||
"""
|
||||
Helper function to return frequency tensors
|
||||
"""Helper function to return frequency tensors
|
||||
:param res: n_dims int tuple of number of frequency modes
|
||||
:return:
|
||||
"""
|
||||
|
||||
n_dims = len(res)
|
||||
freqs = []
|
||||
for dim in range(n_dims - 1):
|
||||
r_ = res[dim]
|
||||
freq = np.fft.fftfreq(r_, d=1/r_)
|
||||
freq = np.fft.fftfreq(r_, d=1 / r_)
|
||||
freqs.append(torch.tensor(freq, dtype=dtype))
|
||||
r_ = res[-1]
|
||||
if exact:
|
||||
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
|
||||
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_), dtype=dtype))
|
||||
else:
|
||||
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
|
||||
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_)[:-1], dtype=dtype))
|
||||
omega = torch.meshgrid(freqs)
|
||||
omega = list(omega)
|
||||
omega = torch.stack(omega, dim=-1)
|
||||
|
||||
return omega
|
||||
|
||||
|
||||
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
|
||||
"""
|
||||
multiply tensor x by i ** deg
|
||||
"""
|
||||
"""Multiply tensor x by i ** deg."""
|
||||
deg %= 4
|
||||
if deg == 0:
|
||||
res = x
|
||||
|
@ -61,17 +61,18 @@ def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
|
|||
res[..., 1] = -res[..., 1]
|
||||
return res
|
||||
|
||||
|
||||
def spec_gaussian_filter(res, sig):
|
||||
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
|
||||
dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
|
||||
filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
|
||||
dis = torch.sqrt(torch.sum(omega**2, dim=-1))
|
||||
filter_ = torch.exp(-0.5 * ((sig * 2 * dis / res[0]) ** 2)).unsqueeze(-1).unsqueeze(-1)
|
||||
filter_.requires_grad = False
|
||||
|
||||
return filter_
|
||||
|
||||
|
||||
def grid_interp(grid, pts, batched=True):
|
||||
"""
|
||||
:param grid: tensor of shape (batch, *size, in_features)
|
||||
""":param grid: tensor of shape (batch, *size, in_features)
|
||||
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
|
||||
:return values at query points
|
||||
"""
|
||||
|
@ -86,7 +87,7 @@ def grid_interp(grid, pts, batched=True):
|
|||
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
||||
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
||||
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
||||
tmp = torch.tensor([0,1],dtype=torch.long)
|
||||
tmp = torch.tensor([0, 1], dtype=torch.long)
|
||||
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
||||
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
||||
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
||||
|
@ -96,14 +97,16 @@ def grid_interp(grid, pts, batched=True):
|
|||
if dim == 2:
|
||||
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
|
||||
else:
|
||||
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features)
|
||||
lat = grid.clone()[
|
||||
ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2],
|
||||
] # (batch, num_points, 2**dim, in_features)
|
||||
|
||||
# weights of neighboring nodes
|
||||
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
||||
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
||||
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
||||
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
||||
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
||||
xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||
pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||
pos_ = pos_.type(pts.dtype)
|
||||
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
||||
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
||||
|
@ -113,38 +116,38 @@ def grid_interp(grid, pts, batched=True):
|
|||
|
||||
return query_values
|
||||
|
||||
|
||||
def scatter_to_grid(inds, vals, size):
|
||||
"""
|
||||
Scatter update values into empty tensor of size size.
|
||||
"""Scatter update values into empty tensor of size size.
|
||||
:param inds: (#values, dims)
|
||||
:param vals: (#values)
|
||||
:param size: tuple for size. len(size)=dims
|
||||
:param size: tuple for size. len(size)=dims.
|
||||
"""
|
||||
dims = inds.shape[1]
|
||||
assert(inds.shape[0] == vals.shape[0])
|
||||
assert(len(size) == dims)
|
||||
assert inds.shape[0] == vals.shape[0]
|
||||
assert len(size) == dims
|
||||
dev = vals.device
|
||||
# result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
|
||||
# # flatten inds
|
||||
result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
|
||||
# flatten inds
|
||||
fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
|
||||
fac = [np.prod(size[i + 1 :]) for i in range(len(size) - 1)] + [1]
|
||||
fac = torch.tensor(fac, device=dev).type(inds.dtype)
|
||||
inds_fold = torch.sum(inds*fac, dim=-1) # [#values,]
|
||||
inds_fold = torch.sum(inds * fac, dim=-1) # [#values,]
|
||||
result.scatter_add_(0, inds_fold, vals)
|
||||
result = result.view(*size)
|
||||
return result
|
||||
|
||||
|
||||
def point_rasterize(pts, vals, size):
|
||||
"""
|
||||
:param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
|
||||
""":param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
|
||||
:param vals: point values, tensor of shape (batch, num_points, features)
|
||||
:param size: len(size)=dim tuple for grid size
|
||||
:return rasterized values (batch, features, res0, res1, res2)
|
||||
"""
|
||||
dim = pts.shape[-1]
|
||||
assert(pts.shape[:2] == vals.shape[:2])
|
||||
assert(pts.shape[2] == dim)
|
||||
assert pts.shape[:2] == vals.shape[:2]
|
||||
assert pts.shape[2] == dim
|
||||
size_list = list(size)
|
||||
size = torch.tensor(size).to(pts.device).float()
|
||||
cubesize = 1.0 / size
|
||||
|
@ -156,20 +159,22 @@ def point_rasterize(pts, vals, size):
|
|||
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
|
||||
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
|
||||
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
|
||||
tmp = torch.tensor([0,1],dtype=torch.long)
|
||||
tmp = torch.tensor([0, 1], dtype=torch.long)
|
||||
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
|
||||
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
|
||||
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
|
||||
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||
# ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
||||
ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
|
||||
ind_b = (
|
||||
torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)
|
||||
) # (batch, num_points, 2**dim)
|
||||
|
||||
# weights of neighboring nodes
|
||||
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
|
||||
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
|
||||
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
|
||||
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
||||
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
|
||||
xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||
pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
|
||||
pos_ = pos_.type(pts.dtype)
|
||||
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
|
||||
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
|
||||
|
@ -187,20 +192,21 @@ def point_rasterize(pts, vals, size):
|
|||
# weighted values
|
||||
vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
|
||||
|
||||
inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
|
||||
inds = inds.view(-1, dim + 2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
|
||||
vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
|
||||
tensor_size = [bs, nf] + size_list
|
||||
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list)
|
||||
[bs, nf, *size_list]
|
||||
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf, *size_list])
|
||||
|
||||
return raster # [batch, nf, res, res, res]
|
||||
|
||||
|
||||
|
||||
##################################################
|
||||
# Below are the utilization functions in general
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
|
@ -226,47 +232,47 @@ class AverageMeter(object):
|
|||
def avgcavg(self):
|
||||
return self.avg.sum().item() / (self.count != 0).sum().item()
|
||||
|
||||
|
||||
def load_model_manual(state_dict, model):
|
||||
new_state_dict = OrderedDict()
|
||||
is_model_parallel = isinstance(model, torch.nn.DataParallel)
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('module.') != is_model_parallel:
|
||||
if k.startswith('module.'):
|
||||
if k.startswith("module.") != is_model_parallel:
|
||||
if k.startswith("module."):
|
||||
# remove module
|
||||
k = k[7:]
|
||||
else:
|
||||
# add module
|
||||
k = 'module.' + k
|
||||
k = "module." + k
|
||||
|
||||
new_state_dict[k]=v
|
||||
new_state_dict[k] = v
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
|
||||
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
|
||||
'''
|
||||
Run marching cubes from PSR grid
|
||||
'''
|
||||
"""Run marching cubes from PSR grid."""
|
||||
batch_size = psr_grid.shape[0]
|
||||
s = psr_grid.shape[-1] # size of psr_grid
|
||||
psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
|
||||
|
||||
if batch_size>1:
|
||||
if batch_size > 1:
|
||||
verts, faces, normals = [], [], []
|
||||
for i in range(batch_size):
|
||||
verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
|
||||
verts.append(verts_cur)
|
||||
faces.append(faces_cur)
|
||||
normals.append(normals_cur)
|
||||
verts = np.stack(verts, axis = 0)
|
||||
faces = np.stack(faces, axis = 0)
|
||||
normals = np.stack(normals, axis = 0)
|
||||
verts = np.stack(verts, axis=0)
|
||||
faces = np.stack(faces, axis=0)
|
||||
normals = np.stack(normals, axis=0)
|
||||
else:
|
||||
try:
|
||||
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
|
||||
except:
|
||||
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
|
||||
if real_scale:
|
||||
verts = verts / (s-1) # scale to range [0, 1]
|
||||
verts = verts / (s - 1) # scale to range [0, 1]
|
||||
else:
|
||||
verts = verts / s # scale to range [0, 1)
|
||||
|
||||
|
@ -278,6 +284,7 @@ def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
|
|||
|
||||
return verts, faces, normals
|
||||
|
||||
|
||||
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
|
||||
verts = verts.squeeze()
|
||||
faces = faces.squeeze()
|
||||
|
@ -293,37 +300,34 @@ def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
|
|||
v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
|
||||
|
||||
# calculate the intersection point of each pixel and the mesh
|
||||
p_inters = w_masked[..., 0, None] * v_a + \
|
||||
w_masked[..., 1, None] * v_b + \
|
||||
w_masked[..., 2, None] * v_c
|
||||
p_inters = w_masked[..., 0, None] * v_a + w_masked[..., 1, None] * v_b + w_masked[..., 2, None] * v_c
|
||||
else:
|
||||
# backproject ndc to world coordinates using z-buffer
|
||||
W, H = img_size[1], img_size[0]
|
||||
xy = uv.to(mask.device)[mask]
|
||||
x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
|
||||
y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
|
||||
x_ndc = 1 - (2 * xy[:, 0]) / (W - 1)
|
||||
y_ndc = 1 - (2 * xy[:, 1]) / (H - 1)
|
||||
z = zbuf.squeeze().reshape(H * W)[mask]
|
||||
xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
|
||||
|
||||
p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
|
||||
|
||||
# if there are outlier points, remove them
|
||||
if (p_inters.max()>1) | (p_inters.min()<-1):
|
||||
mask_bound = (p_inters>=-1) & (p_inters<=1)
|
||||
mask_bound = (mask_bound.sum(dim=-1)==3)
|
||||
mask[mask==True] = mask_bound
|
||||
if (p_inters.max() > 1) | (p_inters.min() < -1):
|
||||
mask_bound = (p_inters >= -1) & (p_inters <= 1)
|
||||
mask_bound = mask_bound.sum(dim=-1) == 3
|
||||
mask[mask is True] = mask_bound
|
||||
p_inters = p_inters[mask_bound]
|
||||
print('!!!!!find outlier!')
|
||||
print("!!!!!find outlier!")
|
||||
|
||||
return p_inters, mask, f_p, w_masked
|
||||
|
||||
|
||||
def mesh_rasterization(verts, faces, pose, img_size):
|
||||
'''
|
||||
Use PyTorch3D to rasterize the mesh given a camera
|
||||
'''
|
||||
"""Use PyTorch3D to rasterize the mesh given a camera."""
|
||||
transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
|
||||
if isinstance(pose, PerspectiveCameras):
|
||||
transformed_v[..., 2] = 1/transformed_v[..., 2]
|
||||
transformed_v[..., 2] = 1 / transformed_v[..., 2]
|
||||
# find p_closest on mesh of each pixel via rasterization
|
||||
transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
|
||||
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
||||
|
@ -331,7 +335,7 @@ def mesh_rasterization(verts, faces, pose, img_size):
|
|||
image_size=img_size,
|
||||
blur_radius=0,
|
||||
faces_per_pixel=1,
|
||||
perspective_correct=False
|
||||
perspective_correct=False,
|
||||
)
|
||||
pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
|
||||
mask = pix_to_face.clone() != -1
|
||||
|
@ -341,11 +345,11 @@ def mesh_rasterization(verts, faces, pose, img_size):
|
|||
|
||||
return pix_to_face, w, mask
|
||||
|
||||
|
||||
def verts_on_largest_mesh(verts, faces):
|
||||
'''
|
||||
verts: Numpy array or Torch.Tensor (N, 3)
|
||||
faces: Numpy array (N, 3)
|
||||
'''
|
||||
"""verts: Numpy array or Torch.Tensor (N, 3)
|
||||
faces: Numpy array (N, 3).
|
||||
"""
|
||||
if torch.is_tensor(faces):
|
||||
verts = verts.squeeze().detach().cpu().numpy()
|
||||
faces = faces.squeeze().int().detach().cpu().numpy()
|
||||
|
@ -356,7 +360,7 @@ def verts_on_largest_mesh(verts, faces):
|
|||
v_large, f_large = verts, faces
|
||||
else:
|
||||
max_idx = conn_size.argmax() # find the index of the largest component
|
||||
v_large = verts[conn_idx==max_idx] # keep points on the largest component
|
||||
v_large = verts[conn_idx == max_idx] # keep points on the largest component
|
||||
|
||||
if True:
|
||||
mesh_largest = trimesh.Trimesh(verts, faces)
|
||||
|
@ -366,36 +370,41 @@ def verts_on_largest_mesh(verts, faces):
|
|||
v_large = v_large.astype(np.float32)
|
||||
return v_large, f_large
|
||||
|
||||
|
||||
def load_pointcloud(in_file):
|
||||
plydata = PlyData.read(in_file)
|
||||
vertices = np.stack([
|
||||
plydata['vertex']['x'],
|
||||
plydata['vertex']['y'],
|
||||
plydata['vertex']['z']
|
||||
], axis=1)
|
||||
vertices = np.stack(
|
||||
[
|
||||
plydata["vertex"]["x"],
|
||||
plydata["vertex"]["y"],
|
||||
plydata["vertex"]["z"],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
return vertices
|
||||
|
||||
|
||||
# General config
|
||||
def load_config(path, default_path=None):
|
||||
''' Loads config file.
|
||||
"""Loads config file.
|
||||
|
||||
Args:
|
||||
path (str): path to config file
|
||||
default_path (str): path to the default config file, loaded when the config does not inherit from another one
|
||||
'''
|
||||
"""
|
||||
# Load configuration from file itself
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
cfg_special = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
# Check if we should inherit from a config
|
||||
inherit_from = cfg_special.get('inherit_from')
|
||||
inherit_from = cfg_special.get("inherit_from")
|
||||
|
||||
# If yes, load this config first as default
|
||||
# If no, use the default_path
|
||||
if inherit_from is not None:
|
||||
cfg = load_config(inherit_from, default_path)
|
||||
elif default_path is not None:
|
||||
with open(default_path, 'r') as f:
|
||||
with open(default_path) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
else:
|
||||
cfg = dict()
|
||||
|
@ -405,65 +414,67 @@ def load_config(path, default_path=None):
|
|||
|
||||
return cfg
|
||||
|
||||
|
||||
def update_config(config, unknown):
|
||||
# update config given args
|
||||
for idx,arg in enumerate(unknown):
|
||||
for idx, arg in enumerate(unknown):
|
||||
if arg.startswith("--"):
|
||||
keys = arg.replace("--","").split(':')
|
||||
assert(len(keys)==2)
|
||||
keys = arg.replace("--", "").split(":")
|
||||
assert len(keys) == 2
|
||||
k1, k2 = keys
|
||||
argtype = type(config[k1][k2])
|
||||
if argtype == bool:
|
||||
v = unknown[idx+1].lower() == 'true'
|
||||
v = unknown[idx + 1].lower() == "true"
|
||||
else:
|
||||
if config[k1][k2] is not None:
|
||||
v = type(config[k1][k2])(unknown[idx+1])
|
||||
v = type(config[k1][k2])(unknown[idx + 1])
|
||||
else:
|
||||
v = unknown[idx+1]
|
||||
print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
|
||||
v = unknown[idx + 1]
|
||||
print(f"Changing {k1}:{k2} ---- {config[k1][k2]} to {v}")
|
||||
config[k1][k2] = v
|
||||
return config
|
||||
|
||||
|
||||
def initialize_logger(cfg):
|
||||
out_dir = cfg['train']['out_dir']
|
||||
out_dir = cfg["train"]["out_dir"]
|
||||
if not out_dir:
|
||||
os.makedirs(out_dir)
|
||||
|
||||
cfg['train']['dir_model'] = os.path.join(out_dir, 'model')
|
||||
os.makedirs(cfg['train']['dir_model'], exist_ok=True)
|
||||
|
||||
if cfg['train']['exp_mesh']:
|
||||
cfg['train']['dir_mesh'] = os.path.join(out_dir, 'vis/mesh')
|
||||
os.makedirs(cfg['train']['dir_mesh'], exist_ok=True)
|
||||
if cfg['train']['exp_pcl']:
|
||||
cfg['train']['dir_pcl'] = os.path.join(out_dir, 'vis/pointcloud')
|
||||
os.makedirs(cfg['train']['dir_pcl'], exist_ok=True)
|
||||
if cfg['train']['vis_rendering']:
|
||||
cfg['train']['dir_rendering'] = os.path.join(out_dir, 'vis/rendering')
|
||||
os.makedirs(cfg['train']['dir_rendering'], exist_ok=True)
|
||||
if cfg['train']['o3d_show']:
|
||||
cfg['train']['dir_o3d'] = os.path.join(out_dir, 'vis/o3d')
|
||||
os.makedirs(cfg['train']['dir_o3d'], exist_ok=True)
|
||||
cfg["train"]["dir_model"] = os.path.join(out_dir, "model")
|
||||
os.makedirs(cfg["train"]["dir_model"], exist_ok=True)
|
||||
|
||||
if cfg["train"]["exp_mesh"]:
|
||||
cfg["train"]["dir_mesh"] = os.path.join(out_dir, "vis/mesh")
|
||||
os.makedirs(cfg["train"]["dir_mesh"], exist_ok=True)
|
||||
if cfg["train"]["exp_pcl"]:
|
||||
cfg["train"]["dir_pcl"] = os.path.join(out_dir, "vis/pointcloud")
|
||||
os.makedirs(cfg["train"]["dir_pcl"], exist_ok=True)
|
||||
if cfg["train"]["vis_rendering"]:
|
||||
cfg["train"]["dir_rendering"] = os.path.join(out_dir, "vis/rendering")
|
||||
os.makedirs(cfg["train"]["dir_rendering"], exist_ok=True)
|
||||
if cfg["train"]["o3d_show"]:
|
||||
cfg["train"]["dir_o3d"] = os.path.join(out_dir, "vis/o3d")
|
||||
os.makedirs(cfg["train"]["dir_o3d"], exist_ok=True)
|
||||
|
||||
logger = logging.getLogger("train")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.handlers = []
|
||||
# ch = logging.StreamHandler()
|
||||
# logger.addHandler(ch)
|
||||
fh = logging.FileHandler(os.path.join(cfg['train']['out_dir'], "log.txt"))
|
||||
fh = logging.FileHandler(os.path.join(cfg["train"]["out_dir"], "log.txt"))
|
||||
logger.addHandler(fh)
|
||||
logger.info('Outout dir: %s', out_dir)
|
||||
logger.info("Outout dir: %s", out_dir)
|
||||
return logger
|
||||
|
||||
|
||||
def update_recursive(dict1, dict2):
|
||||
''' Update two config dictionaries recursively.
|
||||
"""Update two config dictionaries recursively.
|
||||
|
||||
Args:
|
||||
dict1 (dict): first dictionary to be updated
|
||||
dict2 (dict): second dictionary which entries should be used
|
||||
|
||||
'''
|
||||
"""
|
||||
for k, v in dict2.items():
|
||||
if k not in dict1:
|
||||
dict1[k] = dict()
|
||||
|
@ -472,6 +483,7 @@ def update_recursive(dict1, dict2):
|
|||
else:
|
||||
dict1[k] = v
|
||||
|
||||
|
||||
def export_pointcloud(name, points, normals=None):
|
||||
if len(points.shape) > 2:
|
||||
points = points[0]
|
||||
|
@ -487,6 +499,7 @@ def export_pointcloud(name, points, normals=None):
|
|||
pcd.normals = o3d.utility.Vector3dVector(normals)
|
||||
o3d.io.write_point_cloud(name, pcd)
|
||||
|
||||
|
||||
def export_mesh(name, v, f):
|
||||
if len(v.shape) > 2:
|
||||
v, f = v[0], f[0]
|
||||
|
@ -498,59 +511,63 @@ def export_mesh(name, v, f):
|
|||
mesh.triangles = o3d.utility.Vector3iVector(f)
|
||||
o3d.io.write_triangle_mesh(name, mesh)
|
||||
|
||||
|
||||
def scale2onet(p, scale=1.2):
|
||||
'''
|
||||
Scale the point cloud from SAP to ONet range
|
||||
'''
|
||||
"""Scale the point cloud from SAP to ONet range."""
|
||||
return (p - 0.5) * scale
|
||||
|
||||
|
||||
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
|
||||
if model is not None:
|
||||
if schedule is not None:
|
||||
optimizer = torch.optim.Adam([
|
||||
{"params": model.parameters(),
|
||||
"lr": schedule[0].get_learning_rate(epoch)},
|
||||
{"params": inputs,
|
||||
"lr": schedule[1].get_learning_rate(epoch)}])
|
||||
elif 'lr' in cfg['train']:
|
||||
optimizer = torch.optim.Adam([
|
||||
{"params": model.parameters(),
|
||||
"lr": float(cfg['train']['lr'])},
|
||||
{"params": inputs,
|
||||
"lr": float(cfg['train']['lr_pcl'])}])
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": model.parameters(), "lr": schedule[0].get_learning_rate(epoch)},
|
||||
{"params": inputs, "lr": schedule[1].get_learning_rate(epoch)},
|
||||
],
|
||||
)
|
||||
elif "lr" in cfg["train"]:
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": model.parameters(), "lr": float(cfg["train"]["lr"])},
|
||||
{"params": inputs, "lr": float(cfg["train"]["lr_pcl"])},
|
||||
],
|
||||
)
|
||||
else:
|
||||
raise Exception('no known learning rate')
|
||||
msg = "no known learning rate"
|
||||
raise Exception(msg)
|
||||
else:
|
||||
if schedule is not None:
|
||||
optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
|
||||
else:
|
||||
optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
|
||||
optimizer = torch.optim.Adam([inputs], lr=float(cfg["train"]["lr_pcl"]))
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def is_url(url):
|
||||
scheme = urllib.parse.urlparse(url).scheme
|
||||
return scheme in ('http', 'https')
|
||||
return scheme in ("http", "https")
|
||||
|
||||
|
||||
def load_url(url):
|
||||
'''Load a module dictionary from url.
|
||||
"""Load a module dictionary from url.
|
||||
|
||||
Args:
|
||||
url (str): url to saved model
|
||||
'''
|
||||
"""
|
||||
print(url)
|
||||
print('=> Loading checkpoint from url...')
|
||||
print("=> Loading checkpoint from url...")
|
||||
state_dict = model_zoo.load_url(url, progress=True)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
class GaussianSmoothing(nn.Module):
|
||||
"""
|
||||
Apply gaussian smoothing on a
|
||||
"""Apply gaussian smoothing on a
|
||||
1d, 2d or 3d tensor. Filtering is performed separately for each channel
|
||||
in the input using a depthwise convolution.
|
||||
|
||||
Arguments:
|
||||
channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
|
||||
kernel_size (int, sequence): Size of the gaussian kernel.
|
||||
|
@ -558,8 +575,9 @@ class GaussianSmoothing(nn.Module):
|
|||
dim (int, optional): The number of dimensions of the data.
|
||||
Default value is 3.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, sigma, dim=3):
|
||||
super(GaussianSmoothing, self).__init__()
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, numbers.Number):
|
||||
kernel_size = [kernel_size] * dim
|
||||
if isinstance(sigma, numbers.Number):
|
||||
|
@ -569,15 +587,11 @@ class GaussianSmoothing(nn.Module):
|
|||
# gaussian function of each dimension.
|
||||
kernel = 1
|
||||
meshgrids = torch.meshgrid(
|
||||
[
|
||||
torch.arange(size, dtype=torch.float32)
|
||||
for size in kernel_size
|
||||
]
|
||||
[torch.arange(size, dtype=torch.float32) for size in kernel_size],
|
||||
)
|
||||
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||
mean = (size - 1) / 2
|
||||
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
|
||||
torch.exp(-((mgrid - mean) / std) ** 2 / 2)
|
||||
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)
|
||||
|
||||
# Make sure sum of values in gaussian kernel equals 1.
|
||||
kernel = kernel / torch.sum(kernel)
|
||||
|
@ -586,7 +600,7 @@ class GaussianSmoothing(nn.Module):
|
|||
kernel = kernel.view(1, 1, *kernel.size())
|
||||
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||
|
||||
self.register_buffer('weight', kernel)
|
||||
self.register_buffer("weight", kernel)
|
||||
self.groups = channels
|
||||
|
||||
if dim == 1:
|
||||
|
@ -596,36 +610,44 @@ class GaussianSmoothing(nn.Module):
|
|||
elif dim == 3:
|
||||
self.conv = F.conv3d
|
||||
else:
|
||||
msg = f"Only 1, 2 and 3 dimensions are supported. Received {dim}."
|
||||
raise RuntimeError(
|
||||
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
||||
msg,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Apply gaussian filter to input.
|
||||
"""Apply gaussian filter to input.
|
||||
|
||||
Arguments:
|
||||
input (torch.Tensor): Input to apply gaussian filter on.
|
||||
|
||||
Returns:
|
||||
filtered (torch.Tensor): Filtered output.
|
||||
"""
|
||||
return self.conv(input, weight=self.weight, groups=self.groups)
|
||||
|
||||
|
||||
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
|
||||
def get_learning_rate_schedules(schedule_specs):
|
||||
|
||||
schedules = []
|
||||
|
||||
for key in schedule_specs.keys():
|
||||
schedules.append(StepLearningRateSchedule(
|
||||
schedule_specs[key]['initial'],
|
||||
schedules.append(
|
||||
StepLearningRateSchedule(
|
||||
schedule_specs[key]["initial"],
|
||||
schedule_specs[key]["interval"],
|
||||
schedule_specs[key]["factor"],
|
||||
schedule_specs[key]["final"]))
|
||||
schedule_specs[key]["final"],
|
||||
),
|
||||
)
|
||||
return schedules
|
||||
|
||||
|
||||
class LearningRateSchedule:
|
||||
def get_learning_rate(self, epoch):
|
||||
pass
|
||||
|
||||
|
||||
class StepLearningRateSchedule(LearningRateSchedule):
|
||||
def __init__(self, initial, interval, factor, final=1e-6):
|
||||
self.initial = float(initial)
|
||||
|
@ -640,6 +662,7 @@ class StepLearningRateSchedule(LearningRateSchedule):
|
|||
else:
|
||||
return self.final
|
||||
|
||||
|
||||
def adjust_learning_rate(lr_schedules, optimizer, epoch):
|
||||
for i, param_group in enumerate(optimizer.param_groups):
|
||||
param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
|
||||
|
|
106
src/visualize.py
106
src/visualize.py
|
@ -1,41 +1,41 @@
|
|||
import os
|
||||
import torch
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import open3d as o3d
|
||||
import matplotlib.pyplot as plt
|
||||
from skimage import measure
|
||||
from src.utils import calc_inters_points, grid_interp
|
||||
import torch
|
||||
from scipy import ndimage
|
||||
from tqdm import trange
|
||||
from torchvision.utils import save_image
|
||||
from pdb import set_trace as st
|
||||
from tqdm import trange
|
||||
|
||||
from src.utils import calc_inters_points, grid_interp
|
||||
|
||||
|
||||
def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None):
|
||||
''' Visualization.
|
||||
"""Visualization.
|
||||
|
||||
Args:
|
||||
data (dict): data dictionary
|
||||
depth (int): PSR depth
|
||||
out_path (str): output path for the mesh
|
||||
'''
|
||||
"""
|
||||
mesh = o3d.geometry.TriangleMesh()
|
||||
mesh.vertices = o3d.utility.Vector3dVector(verts)
|
||||
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
||||
mesh.paint_uniform_color(np.array([0.7,0.7,0.7]))
|
||||
mesh.paint_uniform_color(np.array([0.7, 0.7, 0.7]))
|
||||
if color_v is not None:
|
||||
mesh.vertex_colors = o3d.utility.Vector3dVector(color_v)
|
||||
|
||||
if vis is not None:
|
||||
dir_o3d = cfg['train']['dir_o3d']
|
||||
wire = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
|
||||
dir_o3d = cfg["train"]["dir_o3d"]
|
||||
o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
|
||||
|
||||
p = points.squeeze(0).detach().cpu().numpy()
|
||||
n = normals.squeeze(0).detach().cpu().numpy()
|
||||
pcd = o3d.geometry.PointCloud()
|
||||
pcd.points = o3d.utility.Vector3dVector(p)
|
||||
pcd.normals = o3d.utility.Vector3dVector(n)
|
||||
pcd.paint_uniform_color(np.array([0.7,0.7,1.0]))
|
||||
pcd.paint_uniform_color(np.array([0.7, 0.7, 1.0]))
|
||||
# pcd = pcd.uniform_down_sample(5)
|
||||
|
||||
vis.clear_geometries()
|
||||
|
@ -43,53 +43,51 @@ def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, co
|
|||
vis.update_geometry(mesh)
|
||||
|
||||
#! Thingi wheel - an example of how to change cameras in Open3D viewers
|
||||
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
|
||||
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
|
||||
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
|
||||
vis.get_view_control().set_front([0.0461, -0.7467, 0.6635])
|
||||
vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638])
|
||||
vis.get_view_control().set_up([0.0520, 0.6651, 0.7449])
|
||||
vis.get_view_control().set_zoom(0.7)
|
||||
vis.poll_events()
|
||||
|
||||
out_path = os.path.join(dir_o3d, '{}.jpg'.format(it))
|
||||
out_path = os.path.join(dir_o3d, f"{it}.jpg")
|
||||
vis.capture_screen_image(out_path)
|
||||
|
||||
vis.clear_geometries()
|
||||
vis.add_geometry(pcd, reset_bounding_box=False)
|
||||
vis.update_geometry(pcd)
|
||||
vis.get_render_option().point_show_normal=True # visualize point normals
|
||||
vis.get_render_option().point_show_normal = True # visualize point normals
|
||||
|
||||
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
|
||||
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
|
||||
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
|
||||
vis.get_view_control().set_front([0.0461, -0.7467, 0.6635])
|
||||
vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638])
|
||||
vis.get_view_control().set_up([0.0520, 0.6651, 0.7449])
|
||||
vis.get_view_control().set_zoom(0.7)
|
||||
vis.poll_events()
|
||||
|
||||
out_path = os.path.join(dir_o3d, '{}_pcd.jpg'.format(it))
|
||||
out_path = os.path.join(dir_o3d, f"{it}_pcd.jpg")
|
||||
vis.capture_screen_image(out_path)
|
||||
|
||||
def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'):
|
||||
|
||||
def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name="video.mp4"):
|
||||
if pose is not None:
|
||||
device = psr_grid.device
|
||||
# get world coordinate of grid points [-1, 1]
|
||||
res = psr_grid.shape[-1]
|
||||
x = torch.linspace(-1, 1, steps=res)
|
||||
co_x, co_y, co_z = torch.meshgrid(x, x, x)
|
||||
co_grid = torch.stack(
|
||||
[co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)],
|
||||
dim=1).to(device).unsqueeze(0)
|
||||
co_grid = torch.stack([co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)], dim=1).to(device).unsqueeze(0)
|
||||
|
||||
# visualize the projected occ_soft value
|
||||
res = 128
|
||||
psr_grid = psr_grid.reshape(-1)
|
||||
out_mask = psr_grid>0
|
||||
in_mask = psr_grid<0
|
||||
out_mask = psr_grid > 0
|
||||
in_mask = psr_grid < 0
|
||||
pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze()
|
||||
vis_mask = (pix[..., 0]>=0) & (pix[..., 0]<=res-1) & \
|
||||
(pix[..., 1]>=0) & (pix[..., 1]<=res-1)
|
||||
vis_mask = (pix[..., 0] >= 0) & (pix[..., 0] <= res - 1) & (pix[..., 1] >= 0) & (pix[..., 1] <= res - 1)
|
||||
pix_out = pix[vis_mask & out_mask]
|
||||
pix_in = pix[vis_mask & in_mask]
|
||||
|
||||
img = torch.ones([res,res]).to(device)
|
||||
psr_grid = torch.sigmoid(- psr_grid * 5)
|
||||
img = torch.ones([res, res]).to(device)
|
||||
psr_grid = torch.sigmoid(-psr_grid * 5)
|
||||
img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask]
|
||||
img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask]
|
||||
# save_image(img, 'tmp.png', normalize=True)
|
||||
|
@ -98,70 +96,70 @@ def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.
|
|||
dir_psr_vis = out_dir
|
||||
os.makedirs(dir_psr_vis, exist_ok=True)
|
||||
psr_grid = psr_grid.squeeze().detach().cpu().numpy()
|
||||
axis = ['x', 'y', 'z']
|
||||
s = psr_grid.shape[0]
|
||||
for idx in trange(s):
|
||||
my_dpi = 100
|
||||
plt.figure(figsize=(1000/my_dpi, 300/my_dpi), dpi=my_dpi)
|
||||
plt.figure(figsize=(1000 / my_dpi, 300 / my_dpi), dpi=my_dpi)
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode='nearest'), cmap='nipy_spectral')
|
||||
plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode="nearest"), cmap="nipy_spectral")
|
||||
plt.clim(-1, 1)
|
||||
plt.colorbar()
|
||||
plt.title('x')
|
||||
plt.title("x")
|
||||
plt.grid("off")
|
||||
plt.axis("off")
|
||||
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode='nearest'), cmap='nipy_spectral')
|
||||
plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode="nearest"), cmap="nipy_spectral")
|
||||
plt.clim(-1, 1)
|
||||
plt.colorbar()
|
||||
plt.title('y')
|
||||
plt.title("y")
|
||||
plt.grid("off")
|
||||
plt.axis("off")
|
||||
|
||||
plt.subplot(1, 3, 3)
|
||||
plt.imshow(ndimage.rotate(psr_grid[:,:,idx], 90, mode='nearest'), cmap='nipy_spectral')
|
||||
plt.imshow(ndimage.rotate(psr_grid[:, :, idx], 90, mode="nearest"), cmap="nipy_spectral")
|
||||
plt.clim(-1, 1)
|
||||
plt.colorbar()
|
||||
plt.title('z')
|
||||
plt.title("z")
|
||||
plt.grid("off")
|
||||
plt.axis("off")
|
||||
|
||||
|
||||
plt.savefig(os.path.join(dir_psr_vis, '{}'.format(idx)), pad_inches = 0, dpi=100)
|
||||
plt.savefig(os.path.join(dir_psr_vis, f"{idx}"), pad_inches=0, dpi=100)
|
||||
plt.close()
|
||||
os.system("rm {}/{}".format(dir_psr_vis, out_video_name))
|
||||
os.system("ffmpeg -framerate 25 -start_number 0 -i {}/%d.png -pix_fmt yuv420p -crf 17 {}/{}".format(dir_psr_vis, dir_psr_vis, out_video_name))
|
||||
os.system(f"rm {dir_psr_vis}/{out_video_name}")
|
||||
os.system(
|
||||
f"ffmpeg -framerate 25 -start_number 0 -i {dir_psr_vis}/%d.png -pix_fmt yuv420p -crf 17 {dir_psr_vis}/{out_video_name}",
|
||||
)
|
||||
return None
|
||||
return None
|
||||
|
||||
def visualize_mesh_phong(v, f, n, pose, img_size, name, device='cpu'):
|
||||
|
||||
def visualize_mesh_phong(v, f, n, pose, img_size, name, device="cpu"):
|
||||
#! Mesh rendering using Phong shading model
|
||||
_, mask, f_p, w = calc_inters_points(v, f, pose, img_size)
|
||||
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
|
||||
n_inters = w[..., 0, None] * n_a.squeeze() + \
|
||||
w[..., 1, None] * n_b.squeeze() + \
|
||||
w[..., 2, None] * n_c.squeeze()
|
||||
n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze()
|
||||
n_inters = n_inters.detach().to(device)
|
||||
light_source = -pose.R@pose.T.squeeze()
|
||||
light_source = -pose.R @ pose.T.squeeze()
|
||||
light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float()
|
||||
diffuse_per = torch.Tensor([0.7,0.7,0.7]).float()
|
||||
ambiant = torch.Tensor([0.3,0.3,0.3]).float()
|
||||
diffuse_per = torch.Tensor([0.7, 0.7, 0.7]).float()
|
||||
ambiant = torch.Tensor([0.3, 0.3, 0.3]).float()
|
||||
|
||||
diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device)
|
||||
|
||||
phong = torch.ones([img_size[0]*img_size[1], 3]).to(device)
|
||||
phong = torch.ones([img_size[0] * img_size[1], 3]).to(device)
|
||||
phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0)
|
||||
pp = phong.reshape(img_size[0], img_size[1], -1)
|
||||
save_image(pp.permute(2, 0, 1), name)
|
||||
|
||||
|
||||
def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None):
|
||||
p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
|
||||
# normals for p_inters
|
||||
n_inters = None
|
||||
if n is not None:
|
||||
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
|
||||
n_inters = w[..., 0, None] * n_a.squeeze() + \
|
||||
w[..., 1, None] * n_b.squeeze() + \
|
||||
w[..., 2, None] * n_c.squeeze()
|
||||
n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze()
|
||||
if ray is not None:
|
||||
ray = ray.squeeze()[mask]
|
||||
|
||||
|
|
159
train.py
159
train.py
|
@ -4,79 +4,90 @@ abspath = os.path.abspath(__file__)
|
|||
dname = os.path.dirname(abspath)
|
||||
os.chdir(dname)
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import open3d as o3d
|
||||
|
||||
import numpy as np; np.set_printoptions(precision=4)
|
||||
import shutil, argparse, time
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from src import config
|
||||
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
|
||||
from src.training import Trainer
|
||||
from src.data import collate_remove_none, worker_init_fn
|
||||
from src.model import Encode2Points
|
||||
from src.utils import load_config, initialize_logger, \
|
||||
AverageMeter, load_model_manual
|
||||
from src.training import Trainer
|
||||
from src.utils import AverageMeter, initialize_logger, load_config, load_model_manual
|
||||
|
||||
np.set_printoptions(precision=4)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='MNIST toy experiment')
|
||||
parser.add_argument('config', type=str, help='Path to config file.')
|
||||
parser.add_argument('--no_cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
|
||||
parser = argparse.ArgumentParser(description="MNIST toy experiment")
|
||||
parser.add_argument("config", type=str, help="Path to config file.")
|
||||
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
|
||||
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
|
||||
|
||||
args = parser.parse_args()
|
||||
cfg = load_config(args.config, 'configs/default.yaml')
|
||||
cfg = load_config(args.config, "configs/default.yaml")
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
input_type = cfg['data']['input_type']
|
||||
batch_size = cfg['train']['batch_size']
|
||||
model_selection_metric = cfg['train']['model_selection_metric']
|
||||
cfg["data"]["input_type"]
|
||||
batch_size = cfg["train"]["batch_size"]
|
||||
model_selection_metric = cfg["train"]["model_selection_metric"]
|
||||
|
||||
# require PyTorch major version >= 1
|
||||
assert(float(torch.__version__.split('.')[-3]) > 0)
|
||||
assert float(torch.__version__.split(".")[-3]) > 0
|
||||
|
||||
# boiler-plate
|
||||
if cfg['train']['timestamp']:
|
||||
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
if cfg["train"]["timestamp"]:
|
||||
cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
logger = initialize_logger(cfg)
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml'))
|
||||
shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))
|
||||
|
||||
logger.info("using GPU: " + torch.cuda.get_device_name(0))
|
||||
|
||||
# TensorBoard writer (torch.utils.tensorboard)
|
||||
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
|
||||
tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
|
||||
if not os.path.exists(tblogdir):
|
||||
os.makedirs(tblogdir, exist_ok=True)
|
||||
writer = SummaryWriter(log_dir=tblogdir)
|
||||
|
||||
|
||||
inputs = None
|
||||
train_dataset = config.get_dataset('train', cfg)
|
||||
val_dataset = config.get_dataset('val', cfg)
|
||||
vis_dataset = config.get_dataset('vis', cfg)
|
||||
|
||||
train_dataset = config.get_dataset("train", cfg)
|
||||
val_dataset = config.get_dataset("val", cfg)
|
||||
vis_dataset = config.get_dataset("vis", cfg)
|
||||
|
||||
collate_fn = collate_remove_none
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'], shuffle=True,
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=cfg["train"]["n_workers"],
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
worker_init_fn=worker_init_fn)
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
num_workers=cfg["train"]["n_workers_val"],
|
||||
shuffle=False,
|
||||
collate_fn=collate_remove_none,
|
||||
worker_init_fn=worker_init_fn)
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
vis_loader = torch.utils.data.DataLoader(
|
||||
vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
|
||||
vis_dataset,
|
||||
batch_size=1,
|
||||
num_workers=cfg["train"]["n_workers_val"],
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn,
|
||||
worker_init_fn=worker_init_fn)
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
|
||||
|
@ -84,38 +95,35 @@ def main():
|
|||
model = Encode2Points(cfg).to(device)
|
||||
|
||||
n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
logger.info('Number of parameters: %d'% n_parameter)
|
||||
logger.info("Number of parameters: %d" % n_parameter)
|
||||
# load model
|
||||
try:
|
||||
# load model
|
||||
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
||||
load_model_manual(state_dict['state_dict'], model)
|
||||
state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
|
||||
load_model_manual(state_dict["state_dict"], model)
|
||||
|
||||
out = "Load model from iteration %d" % state_dict.get('it', 0)
|
||||
out = "Load model from iteration %d" % state_dict.get("it", 0)
|
||||
logger.info(out)
|
||||
# load point cloud
|
||||
except:
|
||||
state_dict = dict()
|
||||
|
||||
metric_val_best = state_dict.get(
|
||||
'loss_val_best', np.inf)
|
||||
metric_val_best = state_dict.get("loss_val_best", np.inf)
|
||||
|
||||
logger.info('Current best validation metric (%s): %.8f'
|
||||
% (model_selection_metric, metric_val_best))
|
||||
logger.info(f"Current best validation metric ({model_selection_metric}): {metric_val_best:.8f}")
|
||||
|
||||
LR = float(cfg['train']['lr'])
|
||||
LR = float(cfg["train"]["lr"])
|
||||
optimizer = optim.Adam(model.parameters(), lr=LR)
|
||||
|
||||
start_epoch = state_dict.get('epoch', -1)
|
||||
it = state_dict.get('it', -1)
|
||||
start_epoch = state_dict.get("epoch", -1)
|
||||
it = state_dict.get("it", -1)
|
||||
|
||||
trainer = Trainer(cfg, optimizer, device=device)
|
||||
runtime = {}
|
||||
runtime['all'] = AverageMeter()
|
||||
runtime["all"] = AverageMeter()
|
||||
|
||||
# training loop
|
||||
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
|
||||
|
||||
for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
|
||||
for batch in train_loader:
|
||||
it += 1
|
||||
|
||||
|
@ -124,62 +132,57 @@ def main():
|
|||
|
||||
# measure elapsed time
|
||||
end = time.time()
|
||||
runtime['all'].update(end - start)
|
||||
runtime["all"].update(end - start)
|
||||
|
||||
if it % cfg['train']['print_every'] == 0:
|
||||
log_text = ('[Epoch %02d] it=%d, loss=%.4f') %(epoch, it, loss)
|
||||
writer.add_scalar('train/loss', loss, it)
|
||||
if it % cfg["train"]["print_every"] == 0:
|
||||
log_text = ("[Epoch %02d] it=%d, loss=%.4f") % (epoch, it, loss)
|
||||
writer.add_scalar("train/loss", loss, it)
|
||||
if loss_each is not None:
|
||||
for k, l in loss_each.items():
|
||||
if l.item() != 0.:
|
||||
log_text += (' loss_%s=%.4f') % (k, l.item())
|
||||
writer.add_scalar('train/%s' % k, l, it)
|
||||
if l.item() != 0.0:
|
||||
log_text += f" loss_{k}={l.item():.4f}"
|
||||
writer.add_scalar("train/%s" % k, l, it)
|
||||
|
||||
log_text += (' time=%.3f / %.2f') % (runtime['all'].val, runtime['all'].sum)
|
||||
log_text += (" time={:.3f} / {:.2f}").format(runtime["all"].val, runtime["all"].sum)
|
||||
logger.info(log_text)
|
||||
|
||||
if (it>0)& (it % cfg['train']['visualize_every'] == 0):
|
||||
if (it > 0) & (it % cfg["train"]["visualize_every"] == 0):
|
||||
for i, batch_vis in enumerate(vis_loader):
|
||||
trainer.save(model, batch_vis, it, i)
|
||||
if i >= 4:
|
||||
break
|
||||
logger.info('Saved mesh and pointcloud')
|
||||
logger.info("Saved mesh and pointcloud")
|
||||
|
||||
# run validation
|
||||
if it > 0 and (it % cfg['train']['validate_every']) == 0:
|
||||
if it > 0 and (it % cfg["train"]["validate_every"]) == 0:
|
||||
eval_dict = trainer.evaluate(val_loader, model)
|
||||
metric_val = eval_dict[model_selection_metric]
|
||||
logger.info('Validation metric (%s): %.4f'
|
||||
% (model_selection_metric, metric_val))
|
||||
logger.info(f"Validation metric ({model_selection_metric}): {metric_val:.4f}")
|
||||
|
||||
for k, v in eval_dict.items():
|
||||
writer.add_scalar('val/%s' % k, v, it)
|
||||
writer.add_scalar("val/%s" % k, v, it)
|
||||
|
||||
if -(metric_val - metric_val_best) >= 0:
|
||||
metric_val_best = metric_val
|
||||
logger.info('New best model (loss %.4f)' % metric_val_best)
|
||||
state = {'epoch': epoch,
|
||||
'it': it,
|
||||
'loss_val_best': metric_val_best}
|
||||
state['state_dict'] = model.state_dict()
|
||||
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt'))
|
||||
logger.info("New best model (loss %.4f)" % metric_val_best)
|
||||
state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
|
||||
state["state_dict"] = model.state_dict()
|
||||
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model_best.pt"))
|
||||
|
||||
# save checkpoint
|
||||
if (epoch > 0) & (it % cfg['train']['checkpoint_every'] == 0):
|
||||
state = {'epoch': epoch,
|
||||
'it': it,
|
||||
'loss_val_best': metric_val_best}
|
||||
pcl = None
|
||||
state['state_dict'] = model.state_dict()
|
||||
if (epoch > 0) & (it % cfg["train"]["checkpoint_every"] == 0):
|
||||
state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
|
||||
state["state_dict"] = model.state_dict()
|
||||
|
||||
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
|
||||
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))
|
||||
|
||||
if (it % cfg['train']['backup_every'] == 0):
|
||||
torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt'))
|
||||
if it % cfg["train"]["backup_every"] == 0:
|
||||
torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % it + ".pt"))
|
||||
logger.info("Backup model at iteration %d" % it)
|
||||
logger.info("Save new model at iteration %d" % it)
|
||||
|
||||
done=time.time()
|
||||
time.time()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in a new issue