🎨 apply auto formatting

Laurent FAINSIN 2023-05-26 14:59:53 +02:00
parent ee1be08bce
commit 41b213e4a5
28 changed files with 2189 additions and 2048 deletions

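The style on the right-hand side of every hunk below (double-quoted strings, trailing commas, lines up to roughly 120 characters, sorted imports, unused imports and variables stripped) is consistent with a Black/Ruff-style autoformatter plus import sorting; the exact tools are an assumption, since the commit does not name them. A representative before/after, using the argument-parsing code that recurs in these files:

import argparse

parser = argparse.ArgumentParser()
# Old style, as on the left-hand side of the hunks:
#   parser.add_argument('--no_cuda', action='store_true', default=False,
#                       help='disables CUDA training')
# New style, as on the right-hand side: double quotes, one line up to ~120 characters.
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")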

@@ -1,155 +1,145 @@
import argparse
import os

import numpy as np
import pandas as pd
import torch
import trimesh
from tqdm import tqdm

from src.data import IndexField, PointCloudField, Shapes3dDataset
from src.eval import MeshEvaluator
from src.utils import load_config, load_pointcloud

np.set_printoptions(precision=4)


def main():
    parser = argparse.ArgumentParser(description="MNIST toy experiment")
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")

    args = parser.parse_args()
    cfg = load_config(args.config, "configs/default.yaml")
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.device("cuda" if use_cuda else "cpu")
    cfg["data"]["data_type"]

    # Shorthands
    out_dir = cfg["train"]["out_dir"]
    generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])
    if cfg["generation"].get("iter", 0) != 0:
        generation_dir += "_%04d" % cfg["generation"]["iter"]
    elif args.iter is not None:
        generation_dir += "_%04d" % args.iter

    print("Evaluate meshes under %s" % generation_dir)
    out_file = os.path.join(generation_dir, "eval_meshes_full.pkl")
    out_file_class = os.path.join(generation_dir, "eval_meshes.csv")

    # PYTORCH VERSION > 1.0.0
    assert float(torch.__version__.split(".")[-3]) > 0

    pointcloud_field = PointCloudField(cfg["data"]["pointcloud_file"])
    fields = {
        "pointcloud": pointcloud_field,
        "idx": IndexField(),
    }

    print("Test split: ", cfg["data"]["test_split"])

    dataset_folder = cfg["data"]["path"]
    dataset = Shapes3dDataset(
        dataset_folder, fields, cfg["data"]["test_split"], categories=cfg["data"]["class"], cfg=cfg,
    )

    # Loader
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

    # Evaluator
    evaluator = MeshEvaluator(n_points=100000)

    eval_dicts = []
    print("Evaluating meshes...")
    for _it, data in enumerate(tqdm(test_loader)):
        if data is None:
            print("Invalid data.")
            continue

        mesh_dir = os.path.join(generation_dir, "meshes")
        pointcloud_dir = os.path.join(generation_dir, "pointcloud")

        # Get index etc.
        idx = data["idx"].item()

        try:
            model_dict = dataset.get_model_dict(idx)
        except AttributeError:
            model_dict = {"model": str(idx), "category": "n/a"}

        modelname = model_dict["model"]
        category_id = model_dict["category"]

        try:
            category_name = dataset.metadata[category_id].get("name", "n/a")
        except AttributeError:
            category_name = "n/a"

        if category_id != "n/a":
            mesh_dir = os.path.join(mesh_dir, category_id)
            pointcloud_dir = os.path.join(pointcloud_dir, category_id)

        # Evaluate
        pointcloud_tgt = data["pointcloud"].squeeze(0).numpy()
        normals_tgt = data["pointcloud.normals"].squeeze(0).numpy()

        eval_dict = {
            "idx": idx,
            "class id": category_id,
            "class name": category_name,
            "modelname": modelname,
        }
        eval_dicts.append(eval_dict)

        # Evaluate mesh
        if cfg["test"]["eval_mesh"]:
            mesh_file = os.path.join(mesh_dir, "%s.off" % modelname)
            if os.path.exists(mesh_file):
                mesh = trimesh.load(mesh_file, process=False)
                eval_dict_mesh = evaluator.eval_mesh(mesh, pointcloud_tgt, normals_tgt)
                for k, v in eval_dict_mesh.items():
                    eval_dict[k + " (mesh)"] = v
            else:
                print("Warning: mesh does not exist: %s" % mesh_file)

        # Evaluate point cloud
        if cfg["test"]["eval_pointcloud"]:
            pointcloud_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
            if os.path.exists(pointcloud_file):
                pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
                eval_dict_pcl = evaluator.eval_pointcloud(pointcloud, pointcloud_tgt)
                for k, v in eval_dict_pcl.items():
                    eval_dict[k + " (pcl)"] = v
            else:
                print("Warning: pointcloud does not exist: %s" % pointcloud_file)

    # Create pandas dataframe and save
    eval_df = pd.DataFrame(eval_dicts)
    eval_df.set_index(["idx"], inplace=True)
    eval_df.to_pickle(out_file)

    # Create CSV file with main statistics
    eval_df_class = eval_df.groupby(by=["class name"]).mean()
    eval_df_class.loc["mean"] = eval_df_class.mean()
    eval_df_class.to_csv(out_file_class)

    # Print results
    print(eval_df_class)


if __name__ == "__main__":
    main()

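The aggregation at the end of the evaluation script is a standard pandas pattern: one row per model, per-class means, plus a global "mean" row. A minimal sketch with hypothetical metric names and values:

import pandas as pd

eval_dicts = [
    {"idx": 0, "class name": "chair", "chamfer (mesh)": 0.012},
    {"idx": 1, "class name": "chair", "chamfer (mesh)": 0.010},
    {"idx": 2, "class name": "table", "chamfer (mesh)": 0.020},
]
eval_df = pd.DataFrame(eval_dicts).set_index("idx")
eval_df_class = eval_df.groupby(by=["class name"]).mean()  # per-class means
eval_df_class.loc["mean"] = eval_df_class.mean()           # overall mean row
print(eval_df_class)  # chair 0.011, table 0.020, mean 0.0155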

@@ -1,155 +1,164 @@
import argparse
import os
import shutil
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from src import config
from src.dpsr import DPSR
from src.model import Encode2Points
from src.utils import (
    export_mesh,
    export_pointcloud,
    is_url,
    load_config,
    load_model_manual,
    load_url,
    mc_from_psr,
    scale2onet,
)

np.set_printoptions(precision=4)


def main():
    parser = argparse.ArgumentParser(description="MNIST toy experiment")
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")

    args = parser.parse_args()
    cfg = load_config(args.config, "configs/default.yaml")
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    cfg["data"]["data_type"]
    cfg["data"]["input_type"]
    vis_n_outputs = cfg["generation"]["vis_n_outputs"]
    if vis_n_outputs is None:
        vis_n_outputs = -1

    # Shorthands
    out_dir = cfg["train"]["out_dir"]
    if not out_dir:
        os.makedirs(out_dir)
    generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])
    out_time_file = os.path.join(generation_dir, "time_generation_full.pkl")
    out_time_file_class = os.path.join(generation_dir, "time_generation.pkl")

    # PYTORCH VERSION > 1.0.0
    assert float(torch.__version__.split(".")[-3]) > 0

    dataset = config.get_dataset("test", cfg, return_idx=True)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)
    model = Encode2Points(cfg).to(device)

    # load model
    try:
        if is_url(cfg["test"]["model_file"]):
            state_dict = load_url(cfg["test"]["model_file"])
        elif cfg["generation"].get("iter", 0) != 0:
            state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % cfg["generation"]["iter"]))
            generation_dir += "_%04d" % cfg["generation"]["iter"]
        elif args.iter is not None:
            state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % args.iter))
        else:
            state_dict = torch.load(os.path.join(out_dir, "model_best.pt"))

        load_model_manual(state_dict["state_dict"], model)
    except:
        print("Model loading error. Exiting.")
        exit()

    # Generator
    generator = config.get_generator(model, cfg, device=device)

    # Determine what to generate
    generate_mesh = cfg["generation"]["generate_mesh"]
    generate_pointcloud = cfg["generation"]["generate_pointcloud"]

    # Statistics
    time_dicts = []

    # Generate
    model.eval()

    dpsr = DPSR(
        res=(
            cfg["generation"]["psr_resolution"],
            cfg["generation"]["psr_resolution"],
            cfg["generation"]["psr_resolution"],
        ),
        sig=cfg["generation"]["psr_sigma"],
    ).to(device)

    # Count how many models already created
    model_counter = defaultdict(int)

    print("Generating...")
    for _it, data in enumerate(tqdm(test_loader)):
        # Output folders
        mesh_dir = os.path.join(generation_dir, "meshes")
        in_dir = os.path.join(generation_dir, "input")
        pointcloud_dir = os.path.join(generation_dir, "pointcloud")
        generation_vis_dir = os.path.join(generation_dir, "vis")

        # Get index etc.
        idx = data["idx"].item()

        try:
            model_dict = dataset.get_model_dict(idx)
        except AttributeError:
            model_dict = {"model": str(idx), "category": "n/a"}

        modelname = model_dict["model"]
        category_id = model_dict["category"]

        try:
            category_name = dataset.metadata[category_id].get("name", "n/a")
        except AttributeError:
            category_name = "n/a"

        if category_id != "n/a":
            mesh_dir = os.path.join(mesh_dir, str(category_id))
            pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
            in_dir = os.path.join(in_dir, str(category_id))

            folder_name = str(category_id)
            if category_name != "n/a":
                folder_name = str(folder_name) + "_" + category_name.split(",")[0]
            generation_vis_dir = os.path.join(generation_vis_dir, folder_name)

        # Create directories if necessary
        if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir):
            os.makedirs(generation_vis_dir)

        if generate_mesh and not os.path.exists(mesh_dir):
            os.makedirs(mesh_dir)

        if generate_pointcloud and not os.path.exists(pointcloud_dir):
            os.makedirs(pointcloud_dir)

        if not os.path.exists(in_dir):
            os.makedirs(in_dir)

        # Timing dict
        time_dict = {
            "idx": idx,
            "class id": category_id,
            "class name": category_name,
            "modelname": modelname,
        }
        time_dicts.append(time_dict)

        # Generate outputs
        out_file_dict = {}

        if generate_mesh:
            #! deploy the generator to a separate class
            out = generator.generate_mesh(data)
@@ -158,60 +167,56 @@ def main():
            time_dict.update(stats_dict)

            # Write output
            mesh_out_file = os.path.join(mesh_dir, "%s.off" % modelname)
            export_mesh(mesh_out_file, scale2onet(v), f)
            out_file_dict["mesh"] = mesh_out_file

        if generate_pointcloud:
            pointcloud_out_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
            export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
            out_file_dict["pointcloud"] = pointcloud_out_file

        if cfg["generation"]["copy_input"]:
            inputs_path = os.path.join(in_dir, "%s.ply" % modelname)
            p = data.get("inputs").to(device)
            export_pointcloud(inputs_path, scale2onet(p))
            out_file_dict["in"] = inputs_path

        # Copy to visualization directory for first vis_n_output samples
        c_it = model_counter[category_id]
        if c_it < vis_n_outputs:
            # Save output files
            "%02d.off" % c_it
            for k, filepath in out_file_dict.items():
                ext = os.path.splitext(filepath)[1]
                out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, k, ext))
                shutil.copyfile(filepath, out_file)

            # Also generate oracle meshes
            if cfg["generation"]["exp_oracle"]:
                points_gt = data.get("gt_points").to(device)
                normals_gt = data.get("gt_points.normals").to(device)
                psr_gt = dpsr(points_gt, normals_gt)
                v, f, _ = mc_from_psr(psr_gt, zero_level=cfg["data"]["zero_level"])
                out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, "mesh_oracle", ".off"))
                export_mesh(out_file, scale2onet(v), f)

        model_counter[category_id] += 1

    # Create pandas dataframe and save
    time_df = pd.DataFrame(time_dicts)
    time_df.set_index(["idx"], inplace=True)
    time_df.to_pickle(out_time_file)

    # Create pickle files with main statistics
    time_df_class = time_df.groupby(by=["class name"]).mean()
    time_df_class.loc["mean"] = time_df_class.mean()
    time_df_class.to_pickle(out_time_file_class)

    # Print results
    print("Timings [s]:")
    print(time_df_class)


if __name__ == "__main__":
    main()

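The oracle-mesh branch above compresses the whole PSR-to-mesh path into four calls. A sketch of that path in isolation; DPSR and mc_from_psr are repo utilities whose signatures are inferred from their usage in this diff, and the resolution/sigma values are made up:

import torch
from src.dpsr import DPSR
from src.utils import export_mesh, mc_from_psr, scale2onet

res, sigma = 128, 2                       # hypothetical PSR resolution and smoothing
dpsr = DPSR(res=(res, res, res), sig=sigma)
points = torch.rand(1, 5000, 3) * 0.99    # oriented point cloud, coordinates in [0, 1)
normals = torch.randn(1, 5000, 3)
psr_grid = dpsr(points, normals)          # differentiable Poisson surface indicator grid
v, f, _ = mc_from_psr(psr_grid, zero_level=0)  # marching cubes at the zero level set
export_mesh("oracle.off", scale2onet(v), f)    # rescale to ONet coordinates and save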
optim.py

@@ -1,79 +1,86 @@
import argparse
import glob
import os
import shutil
import time

import numpy as np
import open3d as o3d
import torch
import trimesh
from plyfile import PlyData
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes
from skimage import measure
from torch.utils.tensorboard import SummaryWriter

from src.optimization import Trainer
from src.utils import (
    AverageMeter,
    adjust_learning_rate,
    export_pointcloud,
    get_learning_rate_schedules,
    initialize_logger,
    load_config,
    update_config,
    update_optimizer,
)

np.set_printoptions(precision=4)


def main():
    parser = argparse.ArgumentParser(description="MNIST toy experiment")
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1457, metavar="S", help="random seed")

    args, unknown = parser.parse_known_args()
    cfg = load_config(args.config, "configs/default.yaml")
    cfg = update_config(cfg, unknown)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    data_type = cfg["data"]["data_type"]
    cfg["data"]["class"]
    print(cfg["train"]["out_dir"])

    # PYTORCH VERSION > 1.0.0
    assert float(torch.__version__.split(".")[-3]) > 0

    # boiler-plate
    if cfg["train"]["timestamp"]:
        cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
    logger = initialize_logger(cfg)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))

    # tensorboardX writer
    tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
    if not os.path.exists(tblogdir):
        os.makedirs(tblogdir)
    SummaryWriter(log_dir=tblogdir)

    # initialize o3d visualizer
    vis = None
    if cfg["train"]["o3d_show"]:
        vis = o3d.visualization.Visualizer()
        vis.create_window(width=cfg["train"]["o3d_window_size"], height=cfg["train"]["o3d_window_size"])

    # initialize dataset
    if data_type == "point":
        if cfg["data"]["object_id"] != -1:
            data_paths = sorted(glob.glob(cfg["data"]["data_path"]))
            data_path = data_paths[cfg["data"]["object_id"]]
            print("Loaded %d/%d object" % (cfg["data"]["object_id"] + 1, len(data_paths)))
        else:
            data_path = cfg["data"]["data_path"]
            print("Data loaded")

        ext = data_path.split(".")[-1]
        if ext == "obj":  # have GT mesh
            mesh = load_objs_as_meshes([data_path], device=device)
            # scale the mesh into unit cube
            verts = mesh.verts_packed()
@@ -81,20 +88,15 @@ def main():
            center = verts.mean(0)
            mesh.offset_verts_(-center.expand(N, 3))
            scale = max((verts - center).abs().max(0)[0])
            mesh.scale_verts_(1.0 / float(scale))

            # important for our DPSR to have the range in [0, 1), not reaching 1
            mesh.scale_verts_(0.9)

            target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True)
        elif ext == "ply":  # only have the point cloud
            plydata = PlyData.read(data_path)
            vertices = np.stack([plydata["vertex"]["x"], plydata["vertex"]["y"], plydata["vertex"]["z"]], axis=1)
            normals = np.stack([plydata["vertex"]["nx"], plydata["vertex"]["ny"], plydata["vertex"]["nz"]], axis=1)
            N = vertices.shape[0]
            center = vertices.mean(0)
            scale = np.max(np.max(np.abs(vertices - center), axis=0))
@@ -104,212 +106,212 @@ def main():
            target_pts = torch.tensor(vertices, device=device)[None].float()
            target_normals = torch.tensor(normals, device=device)[None].float()
            mesh = None  # no GT mesh

        if not torch.is_tensor(center):
            center = torch.from_numpy(center)
        if not torch.is_tensor(scale):
            scale = torch.from_numpy(np.array([scale]))

        data = {
            "target_points": target_pts,
            "target_normals": target_normals,  # normals are never used
            "gt_mesh": mesh,
        }
    else:
        raise NotImplementedError

    # save the input point cloud
    if "target_points" in data.keys():
        outdir_pcl = os.path.join(cfg["train"]["out_dir"], "target_pcl.ply")
        if "target_normals" in data.keys():
            export_pointcloud(outdir_pcl, data["target_points"], data["target_normals"])
        else:
            export_pointcloud(outdir_pcl, data["target_points"])

    # save oracle PSR mesh (mesh from our PSR using GT point+normals)
    if data.get("gt_mesh") is not None:
        gt_verts, gt_faces = data["gt_mesh"].get_mesh_verts_faces(0)
        pts_gt, norms_gt = sample_points_from_meshes(data["gt_mesh"], num_samples=500000, return_normals=True)
        pts_gt = (pts_gt + 1) / 2
        from src.dpsr import DPSR

        dpsr_tmp = DPSR(
            res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
            sig=cfg["model"]["psr_sigma"],
        ).to(device)
        target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
        target = torch.tanh(target)
        s = target.shape[-1]  # size of psr_grid
        psr_grid_numpy = target.squeeze().detach().cpu().numpy()
        verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
        verts = verts / s * 2.0 - 1  # [-1, 1]
        mesh = o3d.geometry.TriangleMesh()
        mesh.vertices = o3d.utility.Vector3dVector(verts)
        mesh.triangles = o3d.utility.Vector3iVector(faces)
        outdir_mesh = os.path.join(cfg["train"]["out_dir"], "oracle_mesh.ply")
        o3d.io.write_triangle_mesh(outdir_mesh, mesh)

    # initialize the source point cloud given an input mesh
    if "input_mesh" in cfg["train"].keys() and os.path.isfile(cfg["train"]["input_mesh"]):
        if cfg["train"]["input_mesh"].split("/")[-2] == "mesh":
            mesh_tmp = trimesh.load_mesh(cfg["train"]["input_mesh"])
            verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
            faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
            mesh = Meshes(verts=verts, faces=faces)
            points, normals = sample_points_from_meshes(
                mesh, num_samples=cfg["data"]["num_points"], return_normals=True,
            )
            # mesh is saved in the original scale of the gt
            points -= center.float().to(device)
            points /= scale.float().to(device)
            points *= 0.9
            # make sure the points are within the range of [0, 1)
            points = points / 2.0 + 0.5
        else:
            # directly initialize from a point cloud
            pcd = o3d.io.read_point_cloud(cfg["train"]["input_mesh"])
            points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
            normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
            points -= center.float().to(device)
            points /= scale.float().to(device)
            points *= 0.9
            points = points / 2.0 + 0.5
    else:  #! initialize our source point cloud from a sphere
        sphere_radius = cfg["model"]["sphere_radius"]
        sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256, 256])
        points, idx = sphere_mesh.sample(cfg["data"]["num_points"], return_index=True)
        points += 0.5  # make sure the points are within the range of [0, 1)
        normals = sphere_mesh.face_normals[idx]
        points = torch.from_numpy(points).unsqueeze(0).to(device)
        normals = torch.from_numpy(normals).unsqueeze(0).to(device)

    points = torch.log(points / (1 - points))  # inverse sigmoid
    inputs = torch.cat([points, normals], axis=-1).float()
    inputs.requires_grad = True
    model = None  # no network

    # initialize optimizer
    cfg["train"]["schedule"]["pcl"]["initial"] = cfg["train"]["lr_pcl"]
    print("Initial learning rate:", cfg["train"]["schedule"]["pcl"]["initial"])
    if "schedule" in cfg["train"]:
        lr_schedules = get_learning_rate_schedules(cfg["train"]["schedule"])
    else:
        lr_schedules = None

    optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules)

    try:
        # load model
        state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
        if ("pcl" in state_dict.keys()) & (state_dict["pcl"] is not None):
            inputs = state_dict["pcl"].to(device)
            inputs.requires_grad = True
        optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get("epoch"), schedule=lr_schedules)
        out = "Load model from epoch %d" % state_dict.get("epoch", 0)
        print(out)
        logger.info(out)
    except:
        state_dict = dict()

    start_epoch = state_dict.get("epoch", -1)
    trainer = Trainer(cfg, optimizer, device=device)
    runtime = {}
    runtime["all"] = AverageMeter()

    # training loop
    for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
        # schedule the learning rate
        if (epoch > 0) & (lr_schedules is not None):
            if epoch % lr_schedules[0].interval == 0:
                adjust_learning_rate(lr_schedules, optimizer, epoch)
                if len(lr_schedules) > 1:
                    print(
                        "[epoch {}] net_lr: {}, pcl_lr: {}".format(
                            epoch, lr_schedules[0].get_learning_rate(epoch), lr_schedules[1].get_learning_rate(epoch),
                        ),
                    )
                else:
                    print(f"[epoch {epoch}] adjust pcl_lr to: {lr_schedules[0].get_learning_rate(epoch)}")

        start = time.time()
        loss, loss_each = trainer.train_step(data, inputs, model, epoch)
        runtime["all"].update(time.time() - start)

        if epoch % cfg["train"]["print_every"] == 0:
            log_text = ("[Epoch %02d] loss=%.5f") % (epoch, loss)
            if loss_each is not None:
                for k, l in loss_each.items():
                    if l.item() != 0.0:
                        log_text += f" loss_{k}={l.item():.5f}"

            log_text += (" time={:.3f} / {:.3f}").format(runtime["all"].val, runtime["all"].sum)
            logger.info(log_text)
            print(log_text)

        # visualize point clouds and meshes
        if (epoch % cfg["train"]["visualize_every"] == 0) & (vis is not None):
            trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)

        # save outputs
        if epoch % cfg["train"]["save_every"] == 0:
            trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy() * (1 / 0.9))

        # save checkpoints
        if (epoch > 0) & (epoch % cfg["train"]["checkpoint_every"] == 0):
            state = {"epoch": epoch}
            if isinstance(inputs, torch.Tensor):
                state["pcl"] = inputs.detach().cpu()
            torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % epoch + ".pt"))
            print("Save new model at epoch %d" % epoch)
            logger.info("Save new model at epoch %d" % epoch)
            torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))

        # resample and gradually add new points to the source pcl
        if (
            (epoch > 0)
            & (cfg["train"]["resample_every"] != 0)
            & (epoch % cfg["train"]["resample_every"] == 0)
            & (epoch < cfg["train"]["total_epochs"])
        ):
            inputs = trainer.point_resampling(inputs)
            optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules)
            trainer = Trainer(cfg, optimizer, device=device)

    # visualize the Open3D outputs
    if cfg["train"]["o3d_show"]:
        out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video.mp4")
        if os.path.isfile(out_video_dir):
            os.system(f"rm {out_video_dir}")
        os.system(
            "ffmpeg -framerate 30 \
            -start_number 0 \
            -i {}/vis/o3d/%04d.jpg \
            -pix_fmt yuv420p \
            -crf 17 {}".format(
                cfg["train"]["out_dir"], out_video_dir,
            ),
        )
        out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video_pcd.mp4")
        if os.path.isfile(out_video_dir):
            os.system(f"rm {out_video_dir}")
        os.system(
            "ffmpeg -framerate 30 \
            -start_number 0 \
            -i {}/vis/o3d/%04d_pcd.jpg \
            -pix_fmt yuv420p \
            -crf 17 {}".format(
                cfg["train"]["out_dir"], out_video_dir,
            ),
        )
        print("Video saved.")


if __name__ == "__main__":
    main()

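The "inverse sigmoid" line in optim.py is the trick that keeps the optimized point cloud inside the unit cube: the raw parameters are unconstrained logits, and (presumably inside the Trainer) a sigmoid maps them back into (0, 1) before DPSR sees them. A self-contained check of the round trip:

import torch

p = torch.rand(4, 3) * 0.98 + 0.01        # points strictly inside (0, 1)
z = torch.log(p / (1 - p))                 # logit / inverse sigmoid, unconstrained
assert torch.allclose(torch.sigmoid(z), p, atol=1e-6)  # sigmoid recovers the points
# Any gradient step on z still yields sigmoid(z) in (0, 1), so points never leave the grid.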

@@ -1,69 +1,80 @@
import argparse
import os

from src.utils import load_config

os.environ["MKL_THREADING_LAYER"] = "GNU"


def main():
    parser = argparse.ArgumentParser(description="MNIST toy experiment")
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--start_res", type=int, default=-1, help="Resolution to start with.")
    parser.add_argument("--object_id", type=int, default=-1, help="Object index.")

    args, unknown = parser.parse_known_args()
    cfg = load_config(args.config, "configs/default.yaml")

    resolutions = [32, 64, 128, 256]
    iterations = [1000, 1000, 1000, 200]
    lrs = [2e-3, 2e-3 * 0.7, 2e-3 * (0.7**2), 2e-3 * (0.7**3)]  # reduce lr

    for idx, (res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
        if res < args.start_res:
            continue
        if res > cfg["model"]["grid_res"]:
            continue

        psr_sigma = 2 if res <= 128 else 3
        if res > 128:
            psr_sigma = 5 if "thingi_noisy" in args.config else 3

        if args.object_id != -1:
            out_dir = os.path.join(cfg["train"]["out_dir"], "object_%02d" % args.object_id, "res_%d" % res)
        else:
            out_dir = os.path.join(cfg["train"]["out_dir"], "res_%d" % res)

        # sample from mesh when resampling is enabled, otherwise reuse the pointcloud
        init_shape = "mesh" if cfg["train"]["resample_every"] > 0 else "pointcloud"
        if args.object_id != -1:
            input_mesh = (
                "None"
                if idx == 0
                else os.path.join(
                    cfg["train"]["out_dir"],
                    "object_%02d" % args.object_id,
                    "res_%d" % (resolutions[idx - 1]),
                    "vis",
                    init_shape,
                    "%04d.ply" % (iterations[idx - 1]),
                )
            )
        else:
            input_mesh = (
                "None"
                if idx == 0
                else os.path.join(
                    cfg["train"]["out_dir"],
                    "res_%d" % (resolutions[idx - 1]),
                    "vis",
                    init_shape,
                    "%04d.ply" % (iterations[idx - 1]),
                )
            )

        cmd = "export MKL_SERVICE_FORCE_INTEL=1 && "
        cmd += (
            "python optim.py %s --model:grid_res %d --model:psr_sigma %d \
            --train:input_mesh %s --train:total_epochs %d \
            --train:out_dir %s --train:lr_pcl %f \
            --data:object_id %d"
            % (args.config, res, psr_sigma, input_mesh, iteration, out_dir, lr, args.object_id)
        )
        print(cmd)
        os.system(cmd)


if __name__ == "__main__":
    main()

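This script runs optim.py once per resolution, warm-starting each stage from the previous stage's output and decaying the learning rate by 0.7x per stage while the PSR grid resolution doubles. The schedule in numbers:

resolutions = [32, 64, 128, 256]
iterations = [1000, 1000, 1000, 200]
lrs = [2e-3 * 0.7**k for k in range(4)]   # 0.002, 0.0014, 0.00098, 0.000686
for res, it, lr in zip(resolutions, iterations, lrs):
    print(f"res={res:3d}  epochs={it:4d}  lr_pcl={lr:.6f}")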

@@ -1,14 +1,16 @@
import multiprocessing
import os
import time

import numpy as np
import torch
from tqdm import tqdm

from src.dpsr import DPSR

data_path = "data/ShapeNet"  # path for ShapeNet from ONet
base = "data"  # output base directory
dataset_name = "shapenet_psr"
multiprocess = True
njobs = 8
save_pointcloud = True
@@ -20,52 +22,56 @@ padding = 1.2
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)


def process_one(obj):
    obj_name = obj.split("/")[-1]
    c = obj.split("/")[-2]

    # create new for the current object
    out_path_cur = os.path.join(base, dataset_name, c)
    out_path_cur_obj = os.path.join(out_path_cur, obj_name)
    os.makedirs(out_path_cur_obj, exist_ok=True)

    gt_path = os.path.join(data_path, c, obj_name, "pointcloud.npz")
    data = np.load(gt_path)
    points = data["points"]
    normals = data["normals"]

    # normalize the point to [0, 1)
    points = points / padding + 0.5
    # to scale back during inference, we should:
    #! p = (p - 0.5) * padding

    if save_pointcloud:
        outdir = os.path.join(out_path_cur_obj, "pointcloud.npz")
        # np.savez(outdir, points=points, normals=normals)
        np.savez(outdir, points=data["points"], normals=data["normals"])
        # return

    if save_psr_field:
        psr_gt = (
            dpsr(torch.from_numpy(points.astype(np.float32))[None], torch.from_numpy(normals.astype(np.float32))[None])
            .squeeze()
            .cpu()
            .numpy()
            .astype(np.float16)
        )

        outdir = os.path.join(out_path_cur_obj, "psr.npz")
        np.savez(outdir, psr=psr_gt)


def main(c):
    print("---------------------------------------")
    print(f"Processing {c} {split}")
    print("---------------------------------------")

    for split in ["train", "val", "test"]:
        fname = os.path.join(data_path, c, split + ".lst")
        with open(fname) as f:
            obj_list = f.read().splitlines()
        obj_list = [c + "/" + s for s in obj_list]

        if multiprocess:
            # multiprocessing.set_start_method('spawn', force=True)
@@ -81,21 +87,30 @@ def main(c):
        else:
            for obj in tqdm(obj_list):
                process_one(obj)

        print(f"Done Processing {c} {split}!")


if __name__ == "__main__":
    classes = [
        "02691156",
        "02828884",
        "02933112",
        "02958343",
        "03211117",
        "03001627",
        "03636649",
        "03691459",
        "04090263",
        "04256520",
        "04379243",
        "04401088",
        "04530566",
    ]

    t_start = time.time()
    for c in classes:
        main(c)
    t_end = time.time()
    print("Total processing time: ", t_end - t_start)

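The preprocessing script maps ONet coordinates into [0, 1) before computing the PSR field; the comment in process_one gives the inverse. A quick round-trip check (the input range is illustrative):

import numpy as np

padding = 1.2
points = np.random.uniform(-0.55, 0.55, size=(100, 3))  # ONet-style coordinates
normed = points / padding + 0.5       # into [0, 1) for DPSR
restored = (normed - 0.5) * padding   # the inverse applied at inference time
assert np.allclose(restored, points)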

@@ -1,146 +1,151 @@
import yaml
from torchvision import transforms

from src import data, generation
from src.dpsr import DPSR


# Generator for final mesh extraction
def get_generator(model, cfg, device, **kwargs):
    """Returns the generator object.

    Args:
        model (nn.Module): Occupancy Network model
        cfg (dict): imported yaml config
        device (device): pytorch device
    """
    if cfg["generation"]["psr_resolution"] == 0:
        psr_res = cfg["model"]["grid_res"]
        psr_sigma = cfg["model"]["psr_sigma"]
    else:
        psr_res = cfg["generation"]["psr_resolution"]
        psr_sigma = cfg["generation"]["psr_sigma"]

    dpsr = DPSR(res=(psr_res, psr_res, psr_res), sig=psr_sigma).to(device)

    generator = generation.Generator3D(
        model,
        device=device,
        threshold=cfg["data"]["zero_level"],
        sample=cfg["generation"]["use_sampling"],
        input_type=cfg["data"]["input_type"],
        padding=cfg["data"]["padding"],
        dpsr=dpsr,
        psr_tanh=cfg["model"]["psr_tanh"],
    )
    return generator


# Datasets
def get_dataset(mode, cfg, return_idx=False):
    """Returns the dataset.

    Args:
        model (nn.Module): the model which is used
        cfg (dict): config dictionary
        return_idx (bool): whether to include an ID field
    """
    dataset_type = cfg["data"]["dataset"]
    dataset_folder = cfg["data"]["path"]
    categories = cfg["data"]["class"]

    # Get split
    splits = {
        "train": cfg["data"]["train_split"],
        "val": cfg["data"]["val_split"],
        "test": cfg["data"]["test_split"],
        "vis": cfg["data"]["val_split"],
    }
    split = splits[mode]

    # Create dataset
    if dataset_type == "Shapes3D":
        fields = get_data_fields(mode, cfg)
        # Input fields
        inputs_field = get_inputs_field(mode, cfg)
        if inputs_field is not None:
            fields["inputs"] = inputs_field

        if return_idx:
            fields["idx"] = data.IndexField()

        dataset = data.Shapes3dDataset(
            dataset_folder,
            fields,
            split=split,
            categories=categories,
            cfg=cfg,
        )
    else:
        raise ValueError('Invalid dataset "%s"' % cfg["data"]["dataset"])

    return dataset


def get_inputs_field(mode, cfg):
    """Returns the inputs fields.

    Args:
        mode (str): the mode which is used
        cfg (dict): config dictionary
    """
    input_type = cfg["data"]["input_type"]

    if input_type is None:
        inputs_field = None
    elif input_type == "pointcloud":
        noise_level = cfg["data"]["pointcloud_noise"]
        if cfg["data"]["pointcloud_outlier_ratio"] > 0:
            transform = transforms.Compose(
                [
                    data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]),
                    data.PointcloudNoise(noise_level),
                    data.PointcloudOutliers(cfg["data"]["pointcloud_outlier_ratio"]),
                ],
            )
        else:
            transform = transforms.Compose(
                [
                    data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]),
                    data.PointcloudNoise(noise_level),
                ],
            )

        data_type = cfg["data"]["data_type"]
        inputs_field = data.PointCloudField(
            cfg["data"]["pointcloud_file"],
            data_type,
            transform,
            multi_files=cfg["data"]["multi_files"],
        )
    else:
        raise ValueError("Invalid input type (%s)" % input_type)
    return inputs_field


def get_data_fields(mode, cfg):
    """Returns the data fields.

    Args:
        mode (str): the mode which is used
        cfg (dict): imported yaml config
    """
    data_type = cfg["data"]["data_type"]
    fields = {}

    if mode in ("val", "test"):
        transform = data.SubsamplePointcloud(100000)
    else:
        transform = data.SubsamplePointcloud(cfg["data"]["num_gt_points"])

    data_name = cfg["data"]["pointcloud_file"]
    fields["gt_points"] = data.PointCloudField(
        data_name, transform=transform, data_type=data_type, multi_files=cfg["data"]["multi_files"],
    )
    if data_type == "psr_full":
        if mode != "test":
            fields["gt_psr"] = data.FullPSRField(multi_files=cfg["data"]["multi_files"])
    else:
        raise ValueError("Invalid data type (%s)" % data_type)

    return fields
@@ -1,14 +1,12 @@
from src.data.core import ( from src.data.core import (
Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together Shapes3dDataset,
) collate_remove_none,
from src.data.fields import ( collate_stack_together,
IndexField, PointCloudField, FullPSRField worker_init_fn,
)
from src.data.transforms import (
PointcloudNoise, SubsamplePointcloud,
PointcloudOutliers,
) )
from src.data.fields import FullPSRField, IndexField, PointCloudField
from src.data.transforms import PointcloudNoise, PointcloudOutliers, SubsamplePointcloud
__all__ = [ __all__ = [
# Core # Core
Shapes3dDataset, Shapes3dDataset,
@@ -1,45 +1,41 @@
import os
import logging import logging
from torch.utils import data import os
from pdb import set_trace as st
import numpy as np import numpy as np
import yaml import yaml
import torch  # needed by worker_init_fn below (torch.set_num_threads)
from torch.utils import data
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Fields # Fields
class Field(object): class Field:
''' Data fields class. """Data fields class."""
'''
def load(self, data_path, idx, category): def load(self, data_path, idx, category):
''' Loads a data point. """Loads a data point.
Args: Args:
data_path (str): path to data file data_path (str): path to data file
idx (int): index of data point idx (int): index of data point
category (int): index of category category (int): index of category
''' """
raise NotImplementedError raise NotImplementedError
def check_complete(self, files): def check_complete(self, files):
''' Checks if set is complete. """Checks if set is complete.
Args: Args:
files: files files: files
''' """
raise NotImplementedError raise NotImplementedError
class Shapes3dDataset(data.Dataset): class Shapes3dDataset(data.Dataset):
''' 3D Shapes dataset class. """3D Shapes dataset class."""
'''
def __init__(self, dataset_folder, fields, split=None, def __init__(self, dataset_folder, fields, split=None, categories=None, no_except=True, transform=None, cfg=None):
categories=None, no_except=True, transform=None, cfg=None): """Initialization of the 3D shape dataset.
''' Initialization of the 3D shape dataset.
Args: Args:
dataset_folder (str): dataset folder dataset_folder (str): dataset folder
@@ -49,7 +45,7 @@ class Shapes3dDataset(data.Dataset):
no_except (bool): no exception no_except (bool): no exception
transform (callable): transformation applied to data points transform (callable): transformation applied to data points
cfg (yaml): config file cfg (yaml): config file
''' """
# Attributes # Attributes
self.dataset_folder = dataset_folder self.dataset_folder = dataset_folder
self.fields = fields self.fields = fields
@@ -60,76 +56,69 @@ class Shapes3dDataset(data.Dataset):
# If categories is None, use all subfolders # If categories is None, use all subfolders
if categories is None: if categories is None:
categories = os.listdir(dataset_folder) categories = os.listdir(dataset_folder)
categories = [c for c in categories categories = [c for c in categories if os.path.isdir(os.path.join(dataset_folder, c))]
if os.path.isdir(os.path.join(dataset_folder, c))]
# Read metadata file # Read metadata file
metadata_file = os.path.join(dataset_folder, 'metadata.yaml') metadata_file = os.path.join(dataset_folder, "metadata.yaml")
if os.path.exists(metadata_file): if os.path.exists(metadata_file):
with open(metadata_file, 'r') as f: with open(metadata_file) as f:
self.metadata = yaml.load(f, Loader=yaml.Loader) self.metadata = yaml.load(f, Loader=yaml.Loader)
else: else:
self.metadata = { self.metadata = {c: {"id": c, "name": "n/a"} for c in categories}
c: {'id': c, 'name': 'n/a'} for c in categories
}
# Set index # Set index
for c_idx, c in enumerate(categories): for c_idx, c in enumerate(categories):
self.metadata[c]['idx'] = c_idx self.metadata[c]["idx"] = c_idx
# Get all models # Get all models
self.models = [] self.models = []
for c_idx, c in enumerate(categories): for c_idx, c in enumerate(categories):
subpath = os.path.join(dataset_folder, c) subpath = os.path.join(dataset_folder, c)
if not os.path.isdir(subpath): if not os.path.isdir(subpath):
logger.warning('Category %s does not exist in dataset.' % c) logger.warning("Category %s does not exist in dataset." % c)
if split is None: if split is None:
self.models += [ self.models += [
{'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ] {"category": c, "model": m}
for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != "")]
] ]
else: else:
split_file = os.path.join(subpath, split + '.lst') split_file = os.path.join(subpath, split + ".lst")
with open(split_file, 'r') as f: with open(split_file) as f:
models_c = f.read().split('\n') models_c = f.read().split("\n")
if '' in models_c: if "" in models_c:
models_c.remove('') models_c.remove("")
self.models += [{"category": c, "model": m} for m in models_c]
self.models += [
{'category': c, 'model': m}
for m in models_c
]
# precompute # precompute
self.split = split self.split = split
def __len__(self): def __len__(self):
''' Returns the length of the dataset. """Returns the length of the dataset."""
'''
return len(self.models) return len(self.models)
def __getitem__(self, idx): def __getitem__(self, idx):
''' Returns an item of the dataset. """Returns an item of the dataset.
Args: Args:
idx (int): ID of data point idx (int): ID of data point
''' """
category = self.models[idx]["category"]
category = self.models[idx]['category'] model = self.models[idx]["model"]
model = self.models[idx]['model'] c_idx = self.metadata[category]["idx"]
c_idx = self.metadata[category]['idx']
model_path = os.path.join(self.dataset_folder, category, model) model_path = os.path.join(self.dataset_folder, category, model)
data = {} data = {}
info = c_idx info = c_idx
if self.cfg['data']['multi_files'] is not None: if self.cfg["data"]["multi_files"] is not None:
idx = np.random.randint(self.cfg['data']['multi_files']) idx = np.random.randint(self.cfg["data"]["multi_files"])
if self.split != 'train': if self.split != "train":
idx = 0 idx = 0
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
@@ -137,9 +126,8 @@ class Shapes3dDataset(data.Dataset):
field_data = field.load(model_path, idx, info) field_data = field.load(model_path, idx, info)
except Exception: except Exception:
if self.no_except: if self.no_except:
logger.warn( logger.warning(
'Error occurred when loading field %s of model %s' f"Error occurred when loading field {field_name} of model {model}",
% (field_name, model)
) )
return None return None
else: else:
@@ -150,7 +138,7 @@ class Shapes3dDataset(data.Dataset):
if k is None: if k is None:
data[field_name] = v data[field_name] = v
else: else:
data['%s.%s' % (field_name, k)] = v data[f"{field_name}.{k}"] = v
else: else:
data[field_name] = field_data data[field_name] = field_data
@@ -158,78 +146,76 @@ class Shapes3dDataset(data.Dataset):
data = self.transform(data) data = self.transform(data)
return data return data
def get_model_dict(self, idx): def get_model_dict(self, idx):
return self.models[idx] return self.models[idx]
def test_model_complete(self, category, model): def test_model_complete(self, category, model):
''' Tests if model is complete. """Tests if model is complete.
Args: Args:
model (str): modelname model (str): modelname
''' """
model_path = os.path.join(self.dataset_folder, category, model) model_path = os.path.join(self.dataset_folder, category, model)
files = os.listdir(model_path) files = os.listdir(model_path)
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
if not field.check_complete(files): if not field.check_complete(files):
logger.warn('Field "%s" is incomplete: %s' logger.warning(f'Field "{field_name}" is incomplete: {model_path}')
% (field_name, model_path))
return False return False
return True return True
def collate_remove_none(batch): def collate_remove_none(batch):
''' Collater that puts each data field into a tensor with outer dimension """Collater that puts each data field into a tensor with outer dimension
batch size. batch size.
Args: Args:
batch: batch batch: batch
''' """
batch = list(filter(lambda x: x is not None, batch)) batch = list(filter(lambda x: x is not None, batch))
return data.dataloader.default_collate(batch) return data.dataloader.default_collate(batch)
def collate_stack_together(batch): def collate_stack_together(batch):
''' Collater that concatenates the point sets of all samples and """Collater that concatenates the point sets of all samples and
records per-point batch indices. records per-point batch indices.
Args: Args:
batch: batch batch: batch
''' """
batch = list(filter(lambda x: x is not None, batch)) batch = list(filter(lambda x: x is not None, batch))
keys = batch[0].keys() keys = batch[0].keys()
concat = {} concat = {}
if len(batch)>1: if len(batch) > 1:
for key in keys: for key in keys:
key_val = [item[key] for item in batch] key_val = [item[key] for item in batch]
concat[key] = np.concatenate(key_val, axis=0) concat[key] = np.concatenate(key_val, axis=0)
if key == 'inputs': if key == "inputs":
n_pts = [item[key].shape[0] for item in batch] n_pts = [item[key].shape[0] for item in batch]
concat['batch_ind'] = np.concatenate( concat["batch_ind"] = np.concatenate([i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
[i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
return data.dataloader.default_collate([concat]) return data.dataloader.default_collate([concat])
else: else:
n_pts = batch[0]['inputs'].shape[0] n_pts = batch[0]["inputs"].shape[0]
batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int) batch[0]["batch_ind"] = np.zeros(n_pts, dtype=int)
return data.dataloader.default_collate(batch) return data.dataloader.default_collate(batch)
def worker_init_fn(worker_id): def worker_init_fn(worker_id):
''' Worker init function to ensure true randomness. """Worker init function to ensure true randomness."""
'''
def set_num_threads(nt): def set_num_threads(nt):
try: try:
import mkl; mkl.set_num_threads(nt) import mkl
except:
mkl.set_num_threads(nt)
except:
pass pass
torch.set_num_threads(1) torch.set_num_threads(1)
os.environ['IPC_ENABLE']='1' os.environ["IPC_ENABLE"] = "1"
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']: for o in ["OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS", "MKL_NUM_THREADS"]:
os.environ[o] = str(nt) os.environ[o] = str(nt)
random_data = os.urandom(4) random_data = os.urandom(4)
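These collate and worker-init helpers plug directly into a PyTorch DataLoader; a minimal sketch, assuming dataset comes from get_dataset above:

    from torch.utils.data import DataLoader

    loader = DataLoader(
        dataset, batch_size=32, shuffle=True, num_workers=4,
        collate_fn=collate_remove_none,  # drops samples whose fields failed to load
        worker_init_fn=worker_init_fn,   # pins BLAS/MKL thread counts per worker
    )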
@@ -1,63 +1,61 @@
import os import os
import glob
import time
import random
from PIL import Image
import numpy as np import numpy as np
import trimesh
from src.data.core import Field from src.data.core import Field
from pdb import set_trace as st
class IndexField(Field): class IndexField(Field):
''' Basic index field.''' """Basic index field."""
def load(self, model_path, idx, category): def load(self, model_path, idx, category):
''' Loads the index field. """Loads the index field.
Args: Args:
model_path (str): path to model model_path (str): path to model
idx (int): ID of data point idx (int): ID of data point
category (int): index of category category (int): index of category
''' """
return idx return idx
def check_complete(self, files): def check_complete(self, files):
''' Check if field is complete. """Check if field is complete.
Args: Args:
files: files files: files
''' """
return True return True
class FullPSRField(Field): class FullPSRField(Field):
def __init__(self, transform=None, multi_files=None): def __init__(self, transform=None, multi_files=None):
self.transform = transform self.transform = transform
# self.unpackbits = unpackbits # self.unpackbits = unpackbits
self.multi_files = multi_files self.multi_files = multi_files
def load(self, model_path, idx, category):
def load(self, model_path, idx, category):
# try: # try:
# t0 = time.time() # t0 = time.time()
if self.multi_files is not None: if self.multi_files is not None:
psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx)) psr_path = os.path.join(model_path, "psr", f"psr_{idx:02d}.npz")
else: else:
psr_path = os.path.join(model_path, 'psr.npz') psr_path = os.path.join(model_path, "psr.npz")
psr_dict = np.load(psr_path) psr_dict = np.load(psr_path)
# t1 = time.time() # t1 = time.time()
psr = psr_dict['psr'] psr = psr_dict["psr"]
psr = psr.astype(np.float32) psr = psr.astype(np.float32)
# t2 = time.time() # t2 = time.time()
# print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0)) # print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0))
data = {None: psr} data = {None: psr}
if self.transform is not None: if self.transform is not None:
data = self.transform(data) data = self.transform(data)
return data return data
class PointCloudField(Field): class PointCloudField(Field):
''' Point cloud field. """Point cloud field.
It provides the field used for point cloud data. These are the points It provides the field used for point cloud data. These are the points
randomly sampled on the mesh. randomly sampled on the mesh.
@@ -66,53 +64,54 @@ class PointCloudField(Field):
file_name (str): file name file_name (str): file name
transform (list): list of transformations applied to data points transform (list): list of transformations applied to data points
multi_files (callable): number of files multi_files (callable): number of files
''' """
def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2): def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
self.file_name = file_name self.file_name = file_name
self.data_type = data_type # to make sure the range of input is correct self.data_type = data_type # to make sure the range of input is correct
self.transform = transform self.transform = transform
self.multi_files = multi_files self.multi_files = multi_files
self.padding = padding self.padding = padding
self.scale = scale self.scale = scale
def load(self, model_path, idx, category): def load(self, model_path, idx, category):
''' Loads the data point. """Loads the data point.
Args: Args:
model_path (str): path to model model_path (str): path to model
idx (int): ID of data point idx (int): ID of data point
category (int): index of category category (int): index of category
''' """
if self.multi_files is None: if self.multi_files is None:
file_path = os.path.join(model_path, self.file_name) file_path = os.path.join(model_path, self.file_name)
else: else:
# num = np.random.randint(self.multi_files) # num = np.random.randint(self.multi_files)
# file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num)) # file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num))
file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx)) file_path = os.path.join(model_path, self.file_name, "pointcloud_%02d.npz" % (idx))
pointcloud_dict = np.load(file_path) pointcloud_dict = np.load(file_path)
points = pointcloud_dict['points'].astype(np.float32) points = pointcloud_dict["points"].astype(np.float32)
normals = pointcloud_dict['normals'].astype(np.float32) normals = pointcloud_dict["normals"].astype(np.float32)
data = { data = {
None: points, None: points,
'normals': normals, "normals": normals,
} }
if self.transform is not None: if self.transform is not None:
data = self.transform(data) data = self.transform(data)
if self.data_type == 'psr_full': if self.data_type == "psr_full":
# scale the point cloud to the range of (0, 1) # scale the point cloud to the range of (0, 1)
data[None] = data[None] / self.scale + 0.5 data[None] = data[None] / self.scale + 0.5
return data return data
def check_complete(self, files): def check_complete(self, files):
''' Check if field is complete. """Check if field is complete.
Args: Args:
files: files files: files
''' """
complete = (self.file_name in files) complete = self.file_name in files
return complete return complete
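For intuition on the psr_full branch above: with the default scale=1.2, a cloud normalized to roughly [-0.55, 0.55] is mapped into the open unit cube. A small numeric check:

    import numpy as np

    pts = np.array([[-0.55, 0.0, 0.55]], dtype=np.float32)  # extremes of the padded cube
    print(pts / 1.2 + 0.5)  # [[0.0417 0.5 0.9583]], safely inside (0, 1)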
@@ -2,24 +2,24 @@ import numpy as np
# Transforms # Transforms
class PointcloudNoise(object): class PointcloudNoise:
''' Point cloud noise transformation class. """Point cloud noise transformation class.
It adds noise to point cloud data. It adds noise to point cloud data.
Args: Args:
stddev (float): standard deviation of the added Gaussian noise stddev (float): standard deviation of the added Gaussian noise
''' """
def __init__(self, stddev): def __init__(self, stddev):
self.stddev = stddev self.stddev = stddev
def __call__(self, data): def __call__(self, data):
''' Calls the transformation. """Calls the transformation.
Args: Args:
data (dictionary): data dictionary data (dictionary): data dictionary
''' """
data_out = data.copy() data_out = data.copy()
points = data[None] points = data[None]
noise = self.stddev * np.random.randn(*points.shape) noise = self.stddev * np.random.randn(*points.shape)
@@ -27,60 +27,63 @@ class PointcloudNoise(object):
data_out[None] = points + noise data_out[None] = points + noise
return data_out return data_out
class PointcloudOutliers(object):
''' Point cloud outlier transformation class. class PointcloudOutliers:
"""Point cloud outlier transformation class.
It adds outliers to point cloud data. It adds outliers to point cloud data.
Args: Args:
ratio (float): fraction of the point cloud replaced by outliers ratio (float): fraction of the point cloud replaced by outliers
''' """
def __init__(self, ratio): def __init__(self, ratio):
self.ratio = ratio self.ratio = ratio
def __call__(self, data): def __call__(self, data):
''' Calls the transformation. """Calls the transformation.
Args: Args:
data (dictionary): data dictionary data (dictionary): data dictionary
''' """
data_out = data.copy() data_out = data.copy()
points = data[None] points = data[None]
n_points = points.shape[0] n_points = points.shape[0]
n_outlier_points = int(n_points*self.ratio) n_outlier_points = int(n_points * self.ratio)
ind = np.random.randint(0, n_points, n_outlier_points) ind = np.random.randint(0, n_points, n_outlier_points)
outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3)) outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
outliers = outliers.astype(np.float32) outliers = outliers.astype(np.float32)
points[ind] = outliers points[ind] = outliers
data_out[None] = points data_out[None] = points
return data_out return data_out
class SubsamplePointcloud(object):
''' Point cloud subsampling transformation class. class SubsamplePointcloud:
"""Point cloud subsampling transformation class.
It subsamples the point cloud data. It subsamples the point cloud data.
Args: Args:
N (int): number of points to be subsampled N (int): number of points to be subsampled
''' """
def __init__(self, N): def __init__(self, N):
self.N = N self.N = N
def __call__(self, data): def __call__(self, data):
''' Calls the transformation. """Calls the transformation.
Args: Args:
data (dict): data dictionary data (dict): data dictionary
''' """
data_out = data.copy() data_out = data.copy()
points = data[None] points = data[None]
indices = np.random.randint(points.shape[0], size=self.N) indices = np.random.randint(points.shape[0], size=self.N)
data_out[None] = points[indices, :] data_out[None] = points[indices, :]
if 'normals' in data.keys(): if "normals" in data.keys():
normals = data['normals'] normals = data["normals"]
data_out['normals'] = normals[indices, :] data_out["normals"] = normals[indices, :]
return data_out return data_out
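A short sketch of how these transforms chain on the {None: points, "normals": normals} data-dict convention used throughout (shapes are illustrative):

    import numpy as np

    data = {
        None: np.random.rand(10000, 3).astype(np.float32),
        "normals": np.random.rand(10000, 3).astype(np.float32),
    }
    for t in [SubsamplePointcloud(3000), PointcloudNoise(0.005), PointcloudOutliers(0.01)]:
        data = t(data)
    print(data[None].shape)  # (3000, 3); about 1% of the points replaced by uniform outliers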
@@ -1,17 +1,20 @@
import os import os
import cv2
import torch
import numpy as np
from glob import glob from glob import glob
from torch.utils import data
from src.utils import load_rgb, load_mask, get_camera_params import cv2
import imageio  # used by load_rgb/load_mask below; assumed missing from this hunk
import numpy as np
import torch
import torch.nn.functional as F  # used by quat_to_rot below
from pytorch3d.renderer import PerspectiveCameras from pytorch3d.renderer import PerspectiveCameras
from skimage import img_as_float32 from skimage import img_as_float32
from torch.utils import data
from src.utils import get_camera_params, load_mask, load_rgb
################################################## ##################################################
# Below are for the differentiable renderer # Below are for the differentiable renderer
# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py # Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py
def load_rgb(path): def load_rgb(path):
img = imageio.imread(path) img = imageio.imread(path)
img = img_as_float32(img) img = img_as_float32(img)
@@ -23,6 +26,7 @@ def load_rgb(path):
# img = img.transpose(2, 0, 1) # img = img.transpose(2, 0, 1)
return img return img
def load_mask(path): def load_mask(path):
alpha = imageio.imread(path, as_gray=True) alpha = imageio.imread(path, as_gray=True)
alpha = img_as_float32(alpha) alpha = img_as_float32(alpha)
@@ -32,13 +36,13 @@
def get_camera_params(uv, pose, intrinsics): def get_camera_params(uv, pose, intrinsics):
if pose.shape[1] == 7: #In case of quaternion vector representation if pose.shape[1] == 7: # In case of quaternion vector representation
cam_loc = pose[:, 4:] cam_loc = pose[:, 4:]
R = quat_to_rot(pose[:,:4]) R = quat_to_rot(pose[:, :4])
p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float() p = torch.eye(4).repeat(pose.shape[0], 1, 1).cuda().float()
p[:, :3, :3] = R p[:, :3, :3] = R
p[:, :3, 3] = cam_loc p[:, :3, 3] = cam_loc
else: # In case of pose matrix representation else: # In case of pose matrix representation
cam_loc = pose[:, :3, 3] cam_loc = pose[:, :3, 3]
p = pose p = pose
@@ -60,25 +64,27 @@ def get_camera_params(uv, pose, intrinsics):
return ray_dirs, cam_loc return ray_dirs, cam_loc
def quat_to_rot(q): def quat_to_rot(q):
batch_size, _ = q.shape batch_size, _ = q.shape
q = F.normalize(q, dim=1) q = F.normalize(q, dim=1)
R = torch.ones((batch_size, 3,3)).cuda() R = torch.ones((batch_size, 3, 3)).cuda()
qr=q[:,0] qr = q[:, 0]
qi = q[:, 1] qi = q[:, 1]
qj = q[:, 2] qj = q[:, 2]
qk = q[:, 3] qk = q[:, 3]
R[:, 0, 0]=1-2 * (qj**2 + qk**2) R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2)
R[:, 0, 1] = 2 * (qj *qi -qk*qr) R[:, 0, 1] = 2 * (qj * qi - qk * qr)
R[:, 0, 2] = 2 * (qi * qk + qr * qj) R[:, 0, 2] = 2 * (qi * qk + qr * qj)
R[:, 1, 0] = 2 * (qj * qi + qk * qr) R[:, 1, 0] = 2 * (qj * qi + qk * qr)
R[:, 1, 1] = 1-2 * (qi**2 + qk**2) R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2)
R[:, 1, 2] = 2*(qj*qk - qi*qr) R[:, 1, 2] = 2 * (qj * qk - qi * qr)
R[:, 2, 0] = 2 * (qk * qi-qj * qr) R[:, 2, 0] = 2 * (qk * qi - qj * qr)
R[:, 2, 1] = 2 * (qj*qk + qi*qr) R[:, 2, 1] = 2 * (qj * qk + qi * qr)
R[:, 2, 2] = 1-2 * (qi**2 + qj**2) R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2)
return R return R
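As a sanity check of quat_to_rot (assuming a CUDA device, since the function allocates with .cuda()), the identity quaternion should map to the identity matrix:

    import torch

    q = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).cuda()  # (w, x, y, z) identity quaternion
    R = quat_to_rot(q)
    assert torch.allclose(R[0], torch.eye(3).cuda(), atol=1e-6)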
def lift(x, y, z, intrinsics): def lift(x, y, z, intrinsics):
# parse intrinsics # parse intrinsics
# intrinsics = intrinsics.cuda() # intrinsics = intrinsics.cuda()
@@ -88,7 +94,16 @@ def lift(x, y, z, intrinsics):
cy = intrinsics[:, 1, 2] cy = intrinsics[:, 1, 2]
sk = intrinsics[:, 0, 1] sk = intrinsics[:, 0, 1]
x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z x_lift = (
(
x
- cx.unsqueeze(-1)
+ cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
- sk.unsqueeze(-1) * y / fy.unsqueeze(-1)
)
/ fx.unsqueeze(-1)
* z
)
y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
# homogeneous # homogeneous
@@ -96,21 +111,18 @@
class PixelNeRFDTUDataset(data.Dataset): class PixelNeRFDTUDataset(data.Dataset):
""" """Processed DTU from pixelNeRF."""
Processed DTU from pixelNeRF
""" def __init__(
def __init__(self, self,
data_dir='data/DTU', data_dir="data/DTU",
scan_id=65, scan_id=65,
img_size=None, img_size=None,
device=None, device=None,
fixed_scale=0, fixed_scale=0,
): ):
data_dir = os.path.join(data_dir, "scan{}".format(scan_id)) data_dir = os.path.join(data_dir, f"scan{scan_id}")
rgb_paths = [ rgb_paths = [x for x in glob(os.path.join(data_dir, "image", "*")) if (x.endswith((".jpg", ".png")))]
x for x in glob(os.path.join(data_dir, "image", "*"))
if (x.endswith(".jpg") or x.endswith(".png"))
]
rgb_paths = sorted(rgb_paths) rgb_paths = sorted(rgb_paths)
mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png"))) mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))
if len(mask_paths) == 0: if len(mask_paths) == 0:
@@ -129,27 +141,24 @@ class PixelNeRFDTUDataset(data.Dataset):
all_T = [] all_T = []
for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)): for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)):
i = sel_indices[idx] i = sel_indices[idx]
rgb = load_rgb(rgb_path) rgb = load_rgb(rgb_path)
mask = load_mask(mask_path) mask = load_mask(mask_path)
rgb[~mask] = 0. rgb[~mask] = 0.0
rgb = torch.from_numpy(rgb).float().to(device) rgb = torch.from_numpy(rgb).float().to(device)
mask = torch.from_numpy(mask).float().to(device) mask = torch.from_numpy(mask).float().to(device)
x_scale = y_scale = 1.0
xy_delta = 0.0
P = all_cam["world_mat_" + str(i)] P = all_cam["world_mat_" + str(i)]
P = P[:3] P = P[:3]
# scale the original shape to really [-0.9, 0.9] # scale the original shape to really [-0.9, 0.9]
if fixed_scale!=0.: if fixed_scale != 0.0:
scale_mat_new = np.eye(4, 4) scale_mat_new = np.eye(4, 4)
scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9] scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new
else: else:
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)]
P = P[:3, :4] P = P[:3, :4]
K, R, t = cv2.decomposeProjectionMatrix(P)[:3] K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
K = K / K[2, 2] K = K / K[2, 2]
@@ -158,38 +167,34 @@ class PixelNeRFDTUDataset(data.Dataset):
########!!!!! ########!!!!!
RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0) RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0)
tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0) tt = torch.from_numpy(-R @ (t[:3] / t[3])).permute(1, 0)
focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0) focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0) pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
im_size = (rgb.shape[1], rgb.shape[0]) im_size = (rgb.shape[1], rgb.shape[0])
# check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC # check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC
s = min(im_size) s = min(im_size)
focal[:, 0] = focal[:, 0] * 2 / (s-1) focal[:, 0] = focal[:, 0] * 2 / (s - 1)
focal[:, 1] = focal[:, 1] * 2 /(s-1) focal[:, 1] = focal[:, 1] * 2 / (s - 1)
pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1) pc[:, 0] = -(pc[:, 0] - (im_size[0] - 1) / 2) * 2 / (s - 1)
pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1) pc[:, 1] = -(pc[:, 1] - (im_size[1] - 1) / 2) * 2 / (s - 1)
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc, camera = PerspectiveCameras(focal_length=-focal, principal_point=pc, device=device, R=RR, T=tt)
device=device, R=RR, T=tt)
# calculate camera rays # calculate camera rays
uv = uv_creation(im_size)[None].float() uv = uv_creation(im_size)[None].float()
pose = np.eye(4, dtype=np.float32) pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = R.transpose() pose[:3, :3] = R.transpose()
pose[:3,3] = (t[:3] / t[3])[:,0] pose[:3, 3] = (t[:3] / t[3])[:, 0]
pose = torch.from_numpy(pose)[None].float() pose = torch.from_numpy(pose)[None].float()
intrinsics = np.eye(4) intrinsics = np.eye(4)
intrinsics[:3, :3] = K intrinsics[:3, :3] = K
intrinsics[0, 1] = 0. #! remove skew for now intrinsics[0, 1] = 0.0 #! remove skew for now
intrinsics = torch.from_numpy(intrinsics)[None].float() intrinsics = torch.from_numpy(intrinsics)[None].float()
rays, _ = get_camera_params(uv, pose, intrinsics) rays, _ = get_camera_params(uv, pose, intrinsics)
rays = -rays.to(device) rays = -rays.to(device)
all_poses.append(camera) all_poses.append(camera)
all_imgs.append(rgb) all_imgs.append(rgb)
all_masks.append(mask) all_masks.append(mask)
@@ -198,7 +203,7 @@ class PixelNeRFDTUDataset(data.Dataset):
# only for neural renderer # only for neural renderer
all_K.append(torch.tensor(K).to(device)) all_K.append(torch.tensor(K).to(device))
all_R.append(torch.tensor(R).to(device)) all_R.append(torch.tensor(R).to(device))
all_T.append(torch.tensor(t[:3]/t[3]).to(device)) all_T.append(torch.tensor(t[:3] / t[3]).to(device))
all_imgs = torch.stack(all_imgs) all_imgs = torch.stack(all_imgs)
all_masks = torch.stack(all_masks) all_masks = torch.stack(all_masks)
@@ -210,19 +215,20 @@ class PixelNeRFDTUDataset(data.Dataset):
all_T = torch.stack(all_T).permute(0, 2, 1).float() all_T = torch.stack(all_T).permute(0, 2, 1).float()
uv = uv_creation((all_imgs.size(2), all_imgs.size(1))) uv = uv_creation((all_imgs.size(2), all_imgs.size(1)))
self.data = {'rgbs': all_imgs, self.data = {
'masks': all_masks, "rgbs": all_imgs,
'poses': all_poses, "masks": all_masks,
'rays': all_rays, "poses": all_poses,
'uv': uv, "rays": all_rays,
'light_pose': all_light_pose, # for rendering lights "uv": uv,
'K': all_K, "light_pose": all_light_pose, # for rendering lights
'R': all_R, "K": all_K,
'T': all_T, "R": all_R,
} "T": all_T,
}
def __len__(self): def __len__(self):
return 1 return 1
def __getitem__(self, idx): def __getitem__(self, idx):
return self.data return self.data
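A hypothetical usage sketch (the data directory and scan id are placeholders); note that the dataset deliberately has length 1 and returns all views in one dict:

    ds = PixelNeRFDTUDataset(data_dir="data/DTU", scan_id=65, device="cuda")
    batch = ds[0]
    print(batch["rgbs"].shape, len(batch["poses"]))  # (n_views, H, W, 3) images, one camera per view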
@@ -1,17 +1,17 @@
import torch
import torch.nn as nn
from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
import numpy as np import numpy as np
import torch
import torch.fft import torch.fft
import torch.nn as nn
from src.utils import fftfreqs, grid_interp, img, point_rasterize, spec_gaussian_filter
class DPSR(nn.Module): class DPSR(nn.Module):
def __init__(self, res, sig=10, scale=True, shift=True): def __init__(self, res, sig=10, scale=True, shift=True):
""" """:param res: tuple of output field resolution. eg., (128,128)
:param res: tuple of output field resolution. eg., (128,128)
:param sig: degree of gaussian smoothing :param sig: degree of gaussian smoothing
""" """
super(DPSR, self).__init__() super().__init__()
self.res = res self.res = res
self.sig = sig self.sig = sig
self.dim = len(res) self.dim = len(res)
@@ -22,45 +22,44 @@ class DPSR(nn.Module):
self.scale = scale self.scale = scale
self.shift = shift self.shift = shift
self.register_buffer("G", G) self.register_buffer("G", G)
def forward(self, V, N): def forward(self, V, N):
""" """:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
:param N: (batch, nv, 2 or 3) tensor for point normals :param N: (batch, nv, 2 or 3) tensor for point normals
:return phi: (batch, res, res, ...) tensor of output indicator function field :return phi: (batch, res, res, ...) tensor of output indicator function field
""" """
assert(V.shape == N.shape) # [b, nv, ndims] assert V.shape == N.shape # [b, nv, ndims]
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2] ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1] ras_s = torch.fft.rfftn(ras_p, dim=(2, 3, 4))
ras_s = ras_s.permute(*tuple([0, *list(range(2, self.dim + 1)), self.dim + 1, 1]))
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
omega *= 2 * np.pi # normalize frequencies omega *= 2 * np.pi # normalize frequencies
omega = omega.to(V.device) omega = omega.to(V.device)
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2) DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1] Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2] Phi = DivN / (Lap + 1e-6) # [b, dim0, dim1, dim2/2+1, 2]
Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b] Phi = Phi.permute(*tuple([[*list(range(1, self.dim + 2)), 0]])) # [dim0, dim1, dim2/2+1, 2, b]
Phi[tuple([0] * self.dim)] = 0 Phi[tuple([0] * self.dim)] = 0
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2] Phi = Phi.permute(*tuple([[self.dim + 1, *list(range(self.dim + 1))]])) # [b, dim0, dim1, dim2/2+1, 2]
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3)) phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1, 2, 3))
if self.shift or self.scale: if self.shift or self.scale:
# ensure values at points are zero # ensure values at points are zero
fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv] fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
if self.shift: # offset points to have mean of 0 if self.shift: # offset points to have mean of 0
offset = torch.mean(fv, dim=-1) # [b,] offset = torch.mean(fv, dim=-1) # [b,]
phi -= offset.view(*tuple([-1] + [1] * self.dim)) phi -= offset.view(*tuple([-1] + [1] * self.dim))
phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]])) phi = phi.permute(*tuple([[*list(range(1, self.dim + 1)), 0]]))
fv0 = phi[tuple([0] * self.dim)] # [b,] fv0 = phi[tuple([0] * self.dim)] # [b,]
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))])) phi = phi.permute(*tuple([[self.dim, *list(range(self.dim))]]))
if self.scale: if self.scale:
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5 phi = -phi / torch.abs(fv0.view(*tuple([-1] + [1] * self.dim))) * 0.5
return phi return phi
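The core of the forward pass is a spectral Poisson solve: the divergence of the rasterized normal field is divided by the Laplacian in Fourier space (Phi = DivN / (Lap + 1e-6)), with the zero frequency pinned to fix the free constant. A minimal 1D sanity check of the same idea, in NumPy rather than torch.fft:

    import numpy as np

    n = 128
    x = np.arange(n) / n
    f = np.sin(2 * np.pi * 3 * x)                    # ground-truth field
    omega = 2 * np.pi * np.fft.rfftfreq(n, d=1 / n)  # angular frequencies
    g_hat = -(omega ** 2) * np.fft.rfft(f)           # spectral Laplacian of f
    lap = -(omega ** 2)
    lap[0] = 1.0                                     # avoid division by zero at DC
    f_hat = g_hat / lap
    f_hat[0] = 0.0                                   # pin the free constant
    assert np.allclose(np.fft.irfft(f_hat, n=n), f, atol=1e-6)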
@@ -1,43 +1,45 @@
import logging import logging
import numpy as np import numpy as np
import trimesh
from pykdtree.kdtree import KDTree from pykdtree.kdtree import KDTree
EMPTY_PCL_DICT = { EMPTY_PCL_DICT = {
'completeness': np.sqrt(3), "completeness": np.sqrt(3),
'accuracy': np.sqrt(3), "accuracy": np.sqrt(3),
'completeness2': 3, "completeness2": 3,
'accuracy2': 3, "accuracy2": 3,
'chamfer': 6, "chamfer": 6,
} }
EMPTY_PCL_DICT_NORMALS = { EMPTY_PCL_DICT_NORMALS = {
'normals completeness': -1., "normals completeness": -1.0,
'normals accuracy': -1., "normals accuracy": -1.0,
'normals': -1., "normals": -1.0,
} }
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MeshEvaluator(object): class MeshEvaluator:
''' Mesh evaluation class. """Mesh evaluation class.
It handles the mesh evaluation process. It handles the mesh evaluation process.
Args: Args:
n_points (int): number of points to be used for evaluation n_points (int): number of points to be used for evaluation.
''' """
def __init__(self, n_points=100000): def __init__(self, n_points=100000):
self.n_points = n_points self.n_points = n_points
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)): def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1.0 / 1000, 1, 1000)):
''' Evaluates a mesh. """Evaluates a mesh.
Args: Args:
mesh (trimesh): mesh which should be evaluated mesh (trimesh): mesh which should be evaluated
pointcloud_tgt (numpy array): target point cloud pointcloud_tgt (numpy array): target point cloud
normals_tgt (numpy array): target normals normals_tgt (numpy array): target normals
thresholds (numpy array): for F-Score thresholds (numpy array): for F-Score.
''' """
if len(mesh.vertices) != 0 and len(mesh.faces) != 0: if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
pointcloud, idx = mesh.sample(self.n_points, return_index=True) pointcloud, idx = mesh.sample(self.n_points, return_index=True)
@@ -47,25 +49,25 @@ class MeshEvaluator(object):
pointcloud = np.empty((0, 3)) pointcloud = np.empty((0, 3))
normals = np.empty((0, 3)) normals = np.empty((0, 3))
out_dict = self.eval_pointcloud( out_dict = self.eval_pointcloud(pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
return out_dict return out_dict
def eval_pointcloud(self, pointcloud, pointcloud_tgt, def eval_pointcloud(
normals=None, normals_tgt=None, self, pointcloud, pointcloud_tgt, normals=None, normals_tgt=None, thresholds=np.linspace(1.0 / 1000, 1, 1000),
thresholds=np.linspace(1./1000, 1, 1000)): ):
''' Evaluates a point cloud. """Evaluates a point cloud.
Args: Args:
pointcloud (numpy array): predicted point cloud pointcloud (numpy array): predicted point cloud
pointcloud_tgt (numpy array): target point cloud pointcloud_tgt (numpy array): target point cloud
normals (numpy array): predicted normals normals (numpy array): predicted normals
normals_tgt (numpy array): target normals normals_tgt (numpy array): target normals
thresholds (numpy array): threshold values for the F-score calculation thresholds (numpy array): threshold values for the F-score calculation.
''' """
# Return maximum losses if pointcloud is empty # Return maximum losses if pointcloud is empty
if pointcloud.shape[0] == 0: if pointcloud.shape[0] == 0:
logger.warn('Empty pointcloud / mesh detected!') logger.warning("Empty pointcloud / mesh detected!")
out_dict = EMPTY_PCL_DICT.copy() out_dict = EMPTY_PCL_DICT.copy()
if normals is not None and normals_tgt is not None: if normals is not None and normals_tgt is not None:
out_dict.update(EMPTY_PCL_DICT_NORMALS) out_dict.update(EMPTY_PCL_DICT_NORMALS)
@@ -74,11 +76,13 @@ class MeshEvaluator(object):
pointcloud = np.asarray(pointcloud) pointcloud = np.asarray(pointcloud)
pointcloud_tgt = np.asarray(pointcloud_tgt) pointcloud_tgt = np.asarray(pointcloud_tgt)
# Completeness: how far are the points of the target point cloud # Completeness: how far are the points of the target point cloud
# from the predicted point cloud # from the predicted point cloud
completeness, completeness_normals = distance_p2p( completeness, completeness_normals = distance_p2p(
pointcloud_tgt, normals_tgt, pointcloud, normals pointcloud_tgt,
normals_tgt,
pointcloud,
normals,
) )
recall = get_threshold_percentage(completeness, thresholds) recall = get_threshold_percentage(completeness, thresholds)
completeness2 = completeness**2 completeness2 = completeness**2
@@ -90,7 +94,10 @@ class MeshEvaluator(object):
# Accuracy: how far are the points of the predicted pointcloud # Accuracy: how far are the points of the predicted pointcloud
# from the target pointcloud # from the target pointcloud
accuracy, accuracy_normals = distance_p2p( accuracy, accuracy_normals = distance_p2p(
pointcloud, normals, pointcloud_tgt, normals_tgt pointcloud,
normals,
pointcloud_tgt,
normals_tgt,
) )
precision = get_threshold_percentage(accuracy, thresholds) precision = get_threshold_percentage(accuracy, thresholds)
accuracy2 = accuracy**2 accuracy2 = accuracy**2
@@ -101,68 +108,61 @@ class MeshEvaluator(object):
# Chamfer distance # Chamfer distance
chamferL2 = 0.5 * (completeness2 + accuracy2) chamferL2 = 0.5 * (completeness2 + accuracy2)
normals_correctness = ( normals_correctness = 0.5 * completeness_normals + 0.5 * accuracy_normals
0.5 * completeness_normals + 0.5 * accuracy_normals
)
chamferL1 = 0.5 * (completeness + accuracy) chamferL1 = 0.5 * (completeness + accuracy)
# F-Score # F-Score
F = [ F = [2 * precision[i] * recall[i] / (precision[i] + recall[i]) for i in range(len(precision))]
2 * precision[i] * recall[i] / (precision[i] + recall[i])
for i in range(len(precision))
]
out_dict = { out_dict = {
'completeness': completeness, "completeness": completeness,
'accuracy': accuracy, "accuracy": accuracy,
'normals completeness': completeness_normals, "normals completeness": completeness_normals,
'normals accuracy': accuracy_normals, "normals accuracy": accuracy_normals,
'normals': normals_correctness, "normals": normals_correctness,
'completeness2': completeness2, "completeness2": completeness2,
'accuracy2': accuracy2, "accuracy2": accuracy2,
'chamfer-L2': chamferL2, "chamfer-L2": chamferL2,
'chamfer-L1': chamferL1, "chamfer-L1": chamferL1,
'f-score': F[9], # threshold = 1.0% "f-score": F[9], # threshold = 1.0%
'f-score-15': F[14], # threshold = 1.5% "f-score-15": F[14], # threshold = 1.5%
'f-score-20': F[19], # threshold = 2.0% "f-score-20": F[19], # threshold = 2.0%
} }
return out_dict return out_dict
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt): def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
''' Computes minimal distances of each point in points_src to points_tgt. """Computes minimal distances of each point in points_src to points_tgt.
Args: Args:
points_src (numpy array): source points points_src (numpy array): source points
normals_src (numpy array): source normals normals_src (numpy array): source normals
points_tgt (numpy array): target points points_tgt (numpy array): target points
normals_tgt (numpy array): target normals normals_tgt (numpy array): target normals.
''' """
kdtree = KDTree(points_tgt) kdtree = KDTree(points_tgt)
dist, idx = kdtree.query(points_src) dist, idx = kdtree.query(points_src)
if normals_src is not None and normals_tgt is not None: if normals_src is not None and normals_tgt is not None:
normals_src = \ normals_src = normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) normals_tgt = normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
normals_tgt = \
normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1) normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
# Handle normals that point in the wrong direction gracefully # Handle normals that point in the wrong direction gracefully
# (mostly due to method not caring about this in generation) # (mostly due to method not caring about this in generation)
normals_dot_product = np.abs(normals_dot_product) normals_dot_product = np.abs(normals_dot_product)
else: else:
normals_dot_product = np.array( normals_dot_product = np.array([np.nan] * points_src.shape[0], dtype=np.float32)
[np.nan] * points_src.shape[0], dtype=np.float32)
return dist, normals_dot_product return dist, normals_dot_product
def get_threshold_percentage(dist, thresholds): def get_threshold_percentage(dist, thresholds):
''' Evaluates a point cloud. """Evaluates a point cloud.
Args: Args:
dist (numpy array): calculated distance dist (numpy array): calculated distance
thresholds (numpy array): threshold values for the F-score calculation thresholds (numpy array): threshold values for the F-score calculation.
''' """
in_threshold = [ in_threshold = [(dist <= t).mean() for t in thresholds]
(dist <= t).mean() for t in thresholds return in_threshold
]
return in_threshold
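Tying the evaluator together, a hedged end-to-end sketch (the mesh path and target file are placeholders):

    import numpy as np
    import trimesh

    evaluator = MeshEvaluator(n_points=100000)
    mesh = trimesh.load("out/mesh.ply")       # hypothetical path
    gt = np.load("gt/pointcloud.npz")         # hypothetical file with (N, 3) arrays
    out = evaluator.eval_mesh(mesh, gt["points"], gt["normals"])
    print(out["chamfer-L1"], out["f-score"])  # F-score at the 1.0% threshold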
@@ -1,11 +1,12 @@
import torch
import time import time
import trimesh
import numpy as np import torch
from src.utils import mc_from_psr from src.utils import mc_from_psr
class Generator3D(object):
''' Generator class for Occupancy Networks. class Generator3D:
"""Generator class for Occupancy Networks.
It provides functions to generate the final mesh as well refining options. It provides functions to generate the final mesh as well refining options.
@@ -17,11 +18,20 @@
padding (float): how much padding should be used for MISE padding (float): how much padding should be used for MISE
sample (bool): whether z should be sampled sample (bool): whether z should be sampled
input_type (str): type of input input_type (str): type of input
''' """
def __init__(self, model, points_batch_size=100000, def __init__(
threshold=0.5, device=None, padding=0.1, self,
sample=False, input_type = None, dpsr=None, psr_tanh=True): model,
points_batch_size=100000,
threshold=0.5,
device=None,
padding=0.1,
sample=False,
input_type=None,
dpsr=None,
psr_tanh=True,
):
self.model = model.to(device) self.model = model.to(device)
self.points_batch_size = points_batch_size self.points_batch_size = points_batch_size
self.threshold = threshold self.threshold = threshold
@@ -31,33 +41,32 @@
self.sample = sample self.sample = sample
self.dpsr = dpsr self.dpsr = dpsr
self.psr_tanh = psr_tanh self.psr_tanh = psr_tanh
def generate_mesh(self, data, return_stats=True): def generate_mesh(self, data, return_stats=True):
''' Generates the output mesh. """Generates the output mesh.
Args: Args:
data (tensor): data tensor data (tensor): data tensor
return_stats (bool): whether stats should be returned return_stats (bool): whether stats should be returned
''' """
self.model.eval() self.model.eval()
device = self.device device = self.device
stats_dict = {} stats_dict = {}
p = data.get('inputs', torch.empty(1, 0)).to(device) p = data.get("inputs", torch.empty(1, 0)).to(device)
t0 = time.time() t0 = time.time()
points, normals = self.model(p) points, normals = self.model(p)
t1 = time.time() t1 = time.time()
psr_grid = self.dpsr(points, normals) psr_grid = self.dpsr(points, normals)
t2 = time.time() t2 = time.time()
v, f, _ = mc_from_psr(psr_grid, v, f, _ = mc_from_psr(psr_grid, zero_level=self.threshold)
zero_level=self.threshold) stats_dict["pcl"] = t1 - t0
stats_dict['pcl'] = t1 - t0 stats_dict["dpsr"] = t2 - t1
stats_dict['dpsr'] = t2 - t1 stats_dict["mc"] = time.time() - t2
stats_dict['mc'] = time.time() - t2 stats_dict["total"] = time.time() - t0
stats_dict['total'] = time.time() - t0
if return_stats: if return_stats:
return v, f, points, normals, stats_dict return v, f, points, normals, stats_dict
else: else:
return v, f, points, normals return v, f, points, normals
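A sketch of driving the generator on one batch, assuming model, dpsr, and a data batch exist as in the training code:

    generator = Generator3D(model, device=device, dpsr=dpsr, threshold=0.5)
    v, f, points, normals, stats = generator.generate_mesh(batch)
    print(stats)  # per-stage timings: "pcl", "dpsr", "mc", "total"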
@@ -1,18 +1,17 @@
import torch
import numpy as np
import time import time
from src.utils import point_rasterize, grid_interp, mc_from_psr, \
calc_inters_points import torch
from src.dpsr import DPSR
import torch.nn as nn import torch.nn as nn
from src.network import encoder_dict, decoder_dict
from src.network import decoder_dict, encoder_dict
from src.network.utils import map2local from src.network.utils import map2local
from src.utils import calc_inters_points, grid_interp, mc_from_psr, point_rasterize
class PSR2Mesh(torch.autograd.Function): class PSR2Mesh(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, psr_grid): def forward(ctx, psr_grid):
""" """In the forward pass we receive a Tensor containing the input and return
In the forward pass we receive a Tensor containing the input and return
a Tensor containing the output. ctx is a context object that can be used a Tensor containing the output. ctx is a context object that can be used
to stash information for backward computation. You can cache arbitrary to stash information for backward computation. You can cache arbitrary
objects for use in the backward pass using the ctx.save_for_backward method. objects for use in the backward pass using the ctx.save_for_backward method.
@@ -29,8 +28,7 @@ class PSR2Mesh(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals): def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
""" """In the backward pass we receive a Tensor containing the gradient of the loss
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss with respect to the output, and we need to compute the gradient of the loss
with respect to the input. with respect to the input.
""" """
@@ -39,17 +37,17 @@
# matrix multiplication between dL/dV and dV/dPSR # matrix multiplication between dL/dV and dV/dPSR
# dV/dPSR = - normals # dV/dPSR = - normals
grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0)) grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
return grad_grid return grad_grid
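The docstrings above describe the standard torch.autograd.Function contract; a minimal self-contained instance of the same save_for_backward pattern:

    import torch

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)  # stash inputs needed in the backward pass
            return x * x

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            return 2 * x * grad_out   # chain rule: dL/dx = dL/dy * dy/dx

    x = torch.tensor([3.0], requires_grad=True)
    Square.apply(x).backward()
    assert x.grad.item() == 6.0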
class PSR2SurfacePoints(torch.autograd.Function): class PSR2SurfacePoints(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample): def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True) verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
verts = verts * 2. - 1. # within the range of [-1, 1] verts = verts * 2.0 - 1.0 # within the range of [-1, 1]
p_all, n_all, mask_all = [], [], [] p_all, n_all, mask_all = [], [], []
for i in range(len(poses)): for i in range(len(poses)):
@ -67,7 +65,6 @@ class PSR2SurfacePoints(torch.autograd.Function):
n_inters_all = torch.cat(n_all, dim=0) n_inters_all = torch.cat(n_all, dim=0)
mask_visible = torch.stack(mask_all, dim=0) mask_visible = torch.stack(mask_all, dim=0)
res = torch.tensor(psr_grid.detach().shape[2]) res = torch.tensor(psr_grid.detach().shape[2])
ctx.save_for_backward(p_inters_all, n_inters_all, res) ctx.save_for_backward(p_inters_all, n_inters_all, res)
@@ -80,30 +77,31 @@
# grad from the p_inters via MLP renderer # grad from the p_inters via MLP renderer
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None]) grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res grad_grid_pts = point_rasterize((pts[None] + 1) / 2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
return grad_grid_pts, None, None, None, None, None return grad_grid_pts, None, None, None, None, None
class Encode2Points(nn.Module): class Encode2Points(nn.Module):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
encoder = cfg['model']['encoder'] encoder = cfg["model"]["encoder"]
decoder = cfg['model']['decoder'] decoder = cfg["model"]["decoder"]
dim = cfg['data']['dim'] # input dim dim = cfg["data"]["dim"] # input dim
c_dim = cfg['model']['c_dim'] c_dim = cfg["model"]["c_dim"]
encoder_kwargs = cfg['model']['encoder_kwargs'] encoder_kwargs = cfg["model"]["encoder_kwargs"]
if encoder_kwargs == None: if encoder_kwargs is None:
encoder_kwargs = {} encoder_kwargs = {}
decoder_kwargs = cfg['model']['decoder_kwargs'] decoder_kwargs = cfg["model"]["decoder_kwargs"]
padding = cfg['data']['padding'] cfg["data"]["padding"]
self.predict_normal = cfg['model']['predict_normal'] self.predict_normal = cfg["model"]["predict_normal"]
self.predict_offset = cfg['model']['predict_offset'] self.predict_offset = cfg["model"]["predict_offset"]
out_dim = 3 out_dim = 3
out_dim_offset = 3 out_dim_offset = 3
num_offset = cfg['data']['num_offset'] num_offset = cfg["data"]["num_offset"]
# each point predict more than one offset to add output points # each point predict more than one offset to add output points
if num_offset > 1: if num_offset > 1:
out_dim_offset = out_dim * num_offset out_dim_offset = out_dim * num_offset
@@ -111,44 +109,40 @@
# local mapping # local mapping
self.map2local = None self.map2local = None
if cfg['model']['local_coord']: if cfg["model"]["local_coord"]:
if 'unet' in encoder_kwargs.keys(): if "unet" in encoder_kwargs.keys():
unit_size = 1 / encoder_kwargs['plane_resolution'] unit_size = 1 / encoder_kwargs["plane_resolution"]
else: else:
unit_size = 1 / encoder_kwargs['grid_resolution'] unit_size = 1 / encoder_kwargs["grid_resolution"]
local_mapping = map2local(unit_size) local_mapping = map2local(unit_size)
self.encoder = encoder_dict[encoder]( self.encoder = encoder_dict[encoder](
dim=dim, c_dim=c_dim, map2local=local_mapping, dim=dim,
**encoder_kwargs c_dim=c_dim,
map2local=local_mapping,
**encoder_kwargs,
) )
if self.predict_normal: if self.predict_normal:
# decoder for normal prediction # decoder for normal prediction
self.decoder_normal = decoder_dict[decoder]( self.decoder_normal = decoder_dict[decoder](dim=dim, c_dim=c_dim, out_dim=out_dim, **decoder_kwargs)
dim=dim, c_dim=c_dim, out_dim=out_dim,
**decoder_kwargs)
if self.predict_offset: if self.predict_offset:
# decoder for offset prediction # decoder for offset prediction
self.decoder_offset = decoder_dict[decoder]( self.decoder_offset = decoder_dict[decoder](
dim=dim, c_dim=c_dim, out_dim=out_dim_offset, dim=dim, c_dim=c_dim, out_dim=out_dim_offset, map2local=local_mapping, **decoder_kwargs,
map2local=local_mapping, )
**decoder_kwargs)
self.s_off = cfg["model"]["s_offset"]
self.s_off = cfg['model']['s_offset']
def forward(self, p): def forward(self, p):
''' Performs a forward pass through the network. """Performs a forward pass through the network.
Args: Args:
p (tensor): input unoriented points p (tensor): input unoriented points
''' """
time_dict = {} time_dict = {}
mask = None
batch_size = p.size(0) batch_size = p.size(0)
points = p.clone() points = p.clone()
@@ -168,14 +162,12 @@ class Encode2Points(nn.Module):
if self.predict_normal: if self.predict_normal:
normals = self.decoder_normal(points, c) normals = self.decoder_normal(points, c)
t2 = time.perf_counter() t2 = time.perf_counter()
time_dict['encode'] = t1 - t0
time_dict['predict'] = t2 - t1
points = torch.clamp(points, 0.0, 0.99)
if self.cfg['model']['normal_normalize']:
normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8)
time_dict["encode"] = t1 - t0
time_dict["predict"] = t2 - t1
points = torch.clamp(points, 0.0, 0.99)
if self.cfg["model"]["normal_normalize"]:
normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-8)
return points, normals return points, normals
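Putting the encoder and DPSR together, a hypothetical end-to-end sketch (cfg is assumed to be a loaded config with the model/data keys used above):

    import torch
    from src.dpsr import DPSR

    model = Encode2Points(cfg).to(device)
    dpsr = DPSR(res=(128, 128, 128), sig=10).to(device)
    p = torch.rand(1, 3000, 3, device=device)  # unoriented input points in [0, 1)
    points, normals = model(p)                 # oriented, densified point cloud
    psr_grid = dpsr(points, normals)           # indicator field via the spectral Poisson solve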
@@ -1,17 +1,19 @@
import numpy as np  # used below (np.log, np.random); assumed missing from this hunk
import torch
import torch.nn as nn  # SAP2Image subclasses nn.Module
from pytorch3d.renderer import (
MeshRasterizer,
MeshRenderer,
PerspectiveCameras,
RasterizationSettings,
SoftSilhouetteShader,
)
from pytorch3d.structures import Meshes
from src.network.net_rgb import RenderingNetwork from src.network.net_rgb import RenderingNetwork
from src.utils import approx_psr_grad from src.utils import approx_psr_grad
from pytorch3d.renderer import (
RasterizationSettings,
PerspectiveCameras,
MeshRenderer,
MeshRasterizer,
SoftSilhouetteShader)
from pytorch3d.structures import Meshes
def approx_psr_grad(psr_grid, res, normalize=True): def approx_psr_grad(psr_grid, res, normalize=True):
delta_x = delta_y = delta_z = 1/res delta_x = delta_y = delta_z = 1 / res
psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze() psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()
grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x
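
approx_psr_grad estimates the PSR grid's spatial gradient with central differences, grad ≈ (f(i+1) − f(i−1)) / (2Δ) with Δ = 1/res, using replication padding at the borders. A quick standalone check on a field that is linear in x (a sketch, not repository code):

import torch

res = 8
delta = 1.0 / res
coords = torch.arange(res, dtype=torch.float32) * delta
# f(x, y, z) = 2x has exact gradient 2 along x everywhere
grid = (2.0 * coords).reshape(1, 1, res, 1, 1).expand(1, 1, res, res, res).contiguous()

psr_pad = torch.nn.ReplicationPad3d(1)(grid).squeeze()
grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta
# interior voxels recover the exact slope; replication padding halves it at the two borders
assert torch.allclose(grad_x[1:-1], torch.full_like(grad_x[1:-1], 2.0))
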
@@ -35,76 +37,77 @@ class SAP2Image(nn.Module):
self.psr2sur = PSR2SurfacePoints.apply self.psr2sur = PSR2SurfacePoints.apply
self.psr2mesh = PSR2Mesh.apply self.psr2mesh = PSR2Mesh.apply
# initialize DPSR # initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'], self.dpsr = DPSR(
cfg['model']['grid_res'], res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
cfg['model']['grid_res']), sig=cfg["model"]["psr_sigma"],
sig=cfg['model']['psr_sigma']) )
self.cfg = cfg self.cfg = cfg
if cfg['train']['l_weight']['rgb'] != 0.: if cfg["train"]["l_weight"]["rgb"] != 0.0:
self.rendering_network = RenderingNetwork(**cfg['model']['renderer']) self.rendering_network = RenderingNetwork(**cfg["model"]["renderer"])
if cfg['train']['l_weight']['mask'] != 0.: if cfg["train"]["l_weight"]["mask"] != 0.0:
# initialize rasterizer # initialize rasterizer
sigma = 1e-4 sigma = 1e-4
raster_settings_soft = RasterizationSettings( raster_settings_soft = RasterizationSettings(
image_size=img_size, image_size=img_size,
blur_radius=np.log(1. / 1e-4 - 1.)*sigma, blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
faces_per_pixel=150, faces_per_pixel=150,
perspective_correct=False perspective_correct=False,
) )
# initialize silhouette renderer # initialize silhouette renderer
self.mesh_rasterizer = MeshRenderer( self.mesh_rasterizer = MeshRenderer(
rasterizer=MeshRasterizer( rasterizer=MeshRasterizer(
raster_settings=raster_settings_soft raster_settings=raster_settings_soft,
), ),
shader=SoftSilhouetteShader() shader=SoftSilhouetteShader(),
) )
self.cfg = cfg self.cfg = cfg
self.img_size = img_size self.img_size = img_size
def forward(self, inputs, data): def forward(self, inputs, data):
points, normals = inputs[...,:3], inputs[...,3:] points, normals = inputs[..., :3], inputs[..., 3:]
points = torch.sigmoid(points) points = torch.sigmoid(points)
normals = normals / normals.norm(dim=-1, keepdim=True) normals = normals / normals.norm(dim=-1, keepdim=True)
# DPSR to get grid # DPSR to get grid
psr_grid = self.dpsr(points, normals).unsqueeze(1) psr_grid = self.dpsr(points, normals).unsqueeze(1)
psr_grid = torch.tanh(psr_grid) psr_grid = torch.tanh(psr_grid)
return self.render_img(psr_grid, data) return self.render_img(psr_grid, data)
def render_img(self, psr_grid, data): def render_img(self, psr_grid, data):
n_views = len(data["masks"])
n_views = len(data['masks']) n_views_per_iter = self.cfg["data"]["n_views_per_iter"]
n_views_per_iter = self.cfg['data']['n_views_per_iter']
self.cfg["model"]["renderer"]["mode"]
rgb_render_mode = self.cfg['model']['renderer']['mode'] uv = data["uv"]
uv = data['uv']
idx = np.random.randint(0, n_views, n_views_per_iter) idx = np.random.randint(0, n_views, n_views_per_iter)
pose = [data['poses'][i] for i in idx] pose = [data["poses"][i] for i in idx]
rgb = data['rgbs'][idx] rgb = data["rgbs"][idx]
mask_gt = data['masks'][idx] mask_gt = data["masks"][idx]
ray = None ray = None
pred_rgb = None pred_rgb = None
pred_mask = None pred_mask = None
if self.cfg['train']['l_weight']['rgb'] != 0.: if self.cfg["train"]["l_weight"]["rgb"] != 0.0:
psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res']) psr_grad = approx_psr_grad(psr_grid, self.cfg["model"]["grid_res"])
p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None) p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2) n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
fea_interp = None fea_interp = None
if 'rays' in data.keys(): if "rays" in data.keys():
ray = data['rays'].squeeze()[idx][visible_mask] ray = data["rays"].squeeze()[idx][visible_mask]
pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp) pred_rgb = self.rendering_network(
p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp,
)
# silhouette loss # silhouette loss
if self.cfg['train']['l_weight']['mask'] != 0.: if self.cfg["train"]["l_weight"]["mask"] != 0.0:
# build mesh # build mesh
v, f, _ = self.psr2mesh(psr_grid) v, f, _ = self.psr2mesh(psr_grid)
v = v * 2. - 1 # within the range of [-1, 1] v = v * 2.0 - 1 # within the range of [-1, 1]
# ! Fast but more GPU usage # ! Fast but more GPU usage
mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()]) mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])
if True: if True:
@@ -114,11 +117,7 @@ class SAP2Image(nn.Module):
T = torch.cat([p.T for p in pose], dim=0) T = torch.cat([p.T for p in pose], dim=0)
focal = torch.cat([p.focal_length for p in pose], dim=0) focal = torch.cat([p.focal_length for p in pose], dim=0)
pp = torch.cat([p.principal_point for p in pose], dim=0) pp = torch.cat([p.principal_point for p in pose], dim=0)
pose_cur = PerspectiveCameras( pose_cur = PerspectiveCameras(focal_length=focal, principal_point=pp, R=R, T=T, device=R.device)
focal_length=focal,
principal_point=pp,
R=R, T=T,
device=R.device)
pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3] pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]
else: else:
pred_mask = [] pred_mask = []
@@ -129,11 +128,11 @@ class SAP2Image(nn.Module):
pred_mask = torch.cat(pred_mask, dim=0) pred_mask = torch.cat(pred_mask, dim=0)
output = { output = {
'rgb': pred_rgb, "rgb": pred_rgb,
'rgb_gt': rgb, "rgb_gt": rgb,
'mask': pred_mask, "mask": pred_mask,
'mask_gt': mask_gt, "mask_gt": mask_gt,
'vis_mask': visible_mask, "vis_mask": visible_mask,
} }
return output return output
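
For orientation, the dictionary returned above feeds the photometric and silhouette losses. A hedged sketch of how it could be consumed — the weight names mirror cfg["train"]["l_weight"], but the specific loss functions and the indexing are assumptions, not the repository's training code:

import torch.nn.functional as F

def sap2image_loss(output, w_rgb, w_mask):
    # weighted sum of an RGB term (visible pixels only) and a silhouette term
    loss = 0.0
    if w_rgb != 0.0 and output["rgb"] is not None:
        # assumes vis_mask is a boolean mask over the same pixel grid as rgb_gt
        rgb_gt = output["rgb_gt"].reshape(-1, 3)[output["vis_mask"].reshape(-1)]
        loss = loss + w_rgb * F.l1_loss(output["rgb"], rgb_gt)
    if w_mask != 0.0 and output["mask"] is not None:
        loss = loss + w_mask * F.mse_loss(output["mask"], output["mask_gt"])
    return loss
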


@@ -1,8 +1,8 @@
from src.network import encoder, decoder from src.network import decoder, encoder
encoder_dict = { encoder_dict = {
'local_pool_pointnet': encoder.LocalPoolPointnet, "local_pool_pointnet": encoder.LocalPoolPointnet,
} }
decoder_dict = { decoder_dict = {
'simple_local': decoder.LocalDecoder, "simple_local": decoder.LocalDecoder,
} }
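
These string-keyed registries are what lets a YAML config pick architectures by name. A sketch of how they are typically consumed (the cfg values and the exact import location of the dicts are illustrative assumptions):

from src.network import encoder_dict, decoder_dict  # assumed import location of the dicts above

cfg_model = {"encoder": "local_pool_pointnet", "decoder": "simple_local"}  # illustrative config
enc = encoder_dict[cfg_model["encoder"]](dim=3, c_dim=32, plane_resolution=64, plane_type=["xz", "xy", "yz"])
dec = decoder_dict[cfg_model["decoder"]](dim=3, c_dim=32, out_dim=3)
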


@@ -1,15 +1,17 @@
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np
from ipdb import set_trace as st from src.network.utils import (
from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \ ResnetBlockFC,
normalize_coordinate normalize_3d_coordinate,
normalize_coordinate,
)
class LocalDecoder(nn.Module): class LocalDecoder(nn.Module):
''' Decoder. """Decoder.
Instead of conditioning on global features, on plane/volume local features. Instead of conditioning on global features, on plane/volume local features.
Args: Args:
dim (int): input dimension dim (int): input dimension
c_dim (int): dimension of latent conditioned code c c_dim (int): dimension of latent conditioned code c
@@ -17,26 +19,31 @@ class LocalDecoder(nn.Module):
n_blocks (int): number of ResnetBlockFC blocks n_blocks (int): number of ResnetBlockFC blocks
leaky (bool): whether to use leaky ReLUs leaky (bool): whether to use leaky ReLUs
sample_mode (str): sampling feature strategy, bilinear|nearest sample_mode (str): sampling feature strategy, bilinear|nearest
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55].
''' """
def __init__(self, dim=3, c_dim=128, out_dim=3, def __init__(
hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None): self,
dim=3,
c_dim=128,
out_dim=3,
hidden_size=256,
n_blocks=5,
leaky=False,
sample_mode="bilinear",
padding=0.1,
map2local=None,
):
super().__init__() super().__init__()
self.c_dim = c_dim self.c_dim = c_dim
self.n_blocks = n_blocks self.n_blocks = n_blocks
if c_dim != 0: if c_dim != 0:
self.fc_c = nn.ModuleList([ self.fc_c = nn.ModuleList([nn.Linear(c_dim, hidden_size) for i in range(n_blocks)])
nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
])
self.fc_p = nn.Linear(dim, hidden_size) self.fc_p = nn.Linear(dim, hidden_size)
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([ResnetBlockFC(hidden_size) for i in range(n_blocks)])
ResnetBlockFC(hidden_size) for i in range(n_blocks)
])
self.fc_out = nn.Linear(hidden_size, out_dim) self.fc_out = nn.Linear(hidden_size, out_dim)
@@ -49,46 +56,45 @@ class LocalDecoder(nn.Module):
self.padding = padding self.padding = padding
self.map2local = map2local self.map2local = map2local
self.out_dim = out_dim self.out_dim = out_dim
def sample_plane_feature(self, p, c, plane='xz'): def sample_plane_feature(self, p, c, plane="xz"):
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1) xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
xy = xy[:, :, None].float() xy = xy[:, :, None].float()
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
c = F.grid_sample(c, vgrid, padding_mode='border', c = F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode).squeeze(-1)
align_corners=True,
mode=self.sample_mode).squeeze(-1)
return c return c
def sample_grid_feature(self, p, c): def sample_grid_feature(self, p, c):
p_nor = normalize_3d_coordinate(p.clone()) p_nor = normalize_3d_coordinate(p.clone())
p_nor = p_nor[:, :, None, None].float() p_nor = p_nor[:, :, None, None].float()
vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1) vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
# actually trilinear interpolation if mode = 'bilinear' # actually trilinear interpolation if mode = 'bilinear'
c = F.grid_sample(c, vgrid, padding_mode='border', c = (
align_corners=True, F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode)
mode=self.sample_mode).squeeze(-1).squeeze(-1) .squeeze(-1)
.squeeze(-1)
)
return c return c
def forward(self, p, c_plane, **kwargs): def forward(self, p, c_plane, **kwargs):
batch_size = p.shape[0] batch_size = p.shape[0]
plane_type = list(c_plane.keys()) plane_type = list(c_plane.keys())
c = 0 c = 0
if 'grid' in plane_type: if "grid" in plane_type:
c += self.sample_grid_feature(p, c_plane['grid']) c += self.sample_grid_feature(p, c_plane["grid"])
if 'xz' in plane_type: if "xz" in plane_type:
c += self.sample_plane_feature(p, c_plane['xz'], plane='xz') c += self.sample_plane_feature(p, c_plane["xz"], plane="xz")
if 'xy' in plane_type: if "xy" in plane_type:
c += self.sample_plane_feature(p, c_plane['xy'], plane='xy') c += self.sample_plane_feature(p, c_plane["xy"], plane="xy")
if 'yz' in plane_type: if "yz" in plane_type:
c += self.sample_plane_feature(p, c_plane['yz'], plane='yz') c += self.sample_plane_feature(p, c_plane["yz"], plane="yz")
c = c.transpose(1, 2) c = c.transpose(1, 2)
p = p.float() p = p.float()
if self.map2local: if self.map2local:
p = self.map2local(p) p = self.map2local(p)
net = self.fc_p(p) net = self.fc_p(p)
for i in range(self.n_blocks): for i in range(self.n_blocks):
@@ -98,9 +104,8 @@ class LocalDecoder(nn.Module):
net = self.blocks[i](net) net = self.blocks[i](net)
out = self.fc_out(self.actvn(net)) out = self.fc_out(self.actvn(net))
if self.out_dim > 3: if self.out_dim > 3:
out = out.reshape(batch_size, -1, 3) out = out.reshape(batch_size, -1, 3)
return out return out
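
A shape-level usage sketch of the decoder above with random tri-plane features (sizes are illustrative; the import path follows the src.network.decoder module shown in this diff):

import torch
from src.network.decoder import LocalDecoder

B, N, c_dim, res = 2, 512, 32, 64
dec = LocalDecoder(dim=3, c_dim=c_dim, out_dim=3, hidden_size=64, n_blocks=2)

p = torch.rand(B, N, 3) - 0.5  # query points inside the (padded) unit cube
c_plane = {axis: torch.randn(B, c_dim, res, res) for axis in ("xz", "xy", "yz")}

out = dec(p, c_plane)  # plane features are bilinearly sampled per point and summed
assert out.shape == (B, N, 3)
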


@@ -1,17 +1,23 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np from torch_scatter import scatter_max, scatter_mean
from src.network.unet3d import UNet3D
from src.network.unet import UNet from src.network.unet import UNet
from ipdb import set_trace as st from src.network.unet3d import UNet3D
from torch_scatter import scatter_mean, scatter_max from src.network.utils import (
from src.network.utils import get_embedder, normalize_3d_coordinate,\ ResnetBlockFC,
coordinate2index, ResnetBlockFC, normalize_coordinate coordinate2index,
get_embedder,
normalize_3d_coordinate,
normalize_coordinate,
)
class LocalPoolPointnet(nn.Module): class LocalPoolPointnet(nn.Module):
''' PointNet-based encoder network with ResNet blocks for each point. """PointNet-based encoder network with ResNet blocks for each point.
Number of input points are fixed. Number of input points are fixed.
Args: Args:
c_dim (int): dimension of latent code c c_dim (int): dimension of latent code c
dim (int): input points dimension dim (int): input points dimension
@@ -22,26 +28,38 @@ class LocalPoolPointnet(nn.Module):
unet3d (bool): whether to use 3D U-Net unet3d (bool): whether to use 3D U-Net
unet3d_kwargs (str): 3D U-Net parameters unet3d_kwargs (str): 3D U-Net parameters
plane_resolution (int): defined resolution for plane feature plane_resolution (int): defined resolution for plane feature
grid_resolution (int): defined resolution for grid feature grid_resolution (int): defined resolution for grid feature
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
n_blocks (int): number of ResnetBlockFC blocks n_blocks (int): number of ResnetBlockFC blocks
map2local (function): map global coordinates to local ones map2local (function): map global coordinates to local ones
pos_encoding (int): frequency for the positional encoding pos_encoding (int): frequency for the positional encoding
''' """
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', def __init__(
unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, self,
plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5, c_dim=128,
map2local=None, pos_encoding=0): dim=3,
hidden_dim=128,
scatter_type="max",
unet=False,
unet_kwargs=None,
unet3d=False,
unet3d_kwargs=None,
plane_resolution=None,
grid_resolution=None,
plane_type="xz",
padding=0.1,
n_blocks=5,
map2local=None,
pos_encoding=0,
):
super().__init__() super().__init__()
self.c_dim = c_dim self.c_dim = c_dim
self.fc_pos = nn.Linear(dim, 2*hidden_dim) self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)])
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
])
self.fc_c = nn.Linear(hidden_dim, c_dim) self.fc_c = nn.Linear(hidden_dim, c_dim)
self.actvn = nn.ReLU() self.actvn = nn.ReLU()
@@ -59,34 +77,36 @@ class LocalPoolPointnet(nn.Module):
self.reso_grid = grid_resolution self.reso_grid = grid_resolution
self.plane_type = plane_type self.plane_type = plane_type
self.padding = padding self.padding = padding
self.fc_pos = nn.Linear(dim, 2*hidden_dim) self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
self.pe = None self.pe = None
if pos_encoding > 0: if pos_encoding > 0:
embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim) embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim)
self.pe = embed_fn self.pe = embed_fn
self.fc_pos = nn.Linear(input_ch, 2*hidden_dim) self.fc_pos = nn.Linear(input_ch, 2 * hidden_dim)
self.map2local = map2local self.map2local = map2local
if scatter_type == 'max': if scatter_type == "max":
self.scatter = scatter_max self.scatter = scatter_max
elif scatter_type == 'mean': elif scatter_type == "mean":
self.scatter = scatter_mean self.scatter = scatter_mean
else: else:
raise ValueError('incorrect scatter type') msg = "incorrect scatter type"
raise ValueError(msg)
def generate_plane_features(self, p, c, plane='xz'): def generate_plane_features(self, p, c, plane="xz"):
# acquire indices of features in plane # acquire indices of features in plane
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1) xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
index = coordinate2index(xy, self.reso_plane) index = coordinate2index(xy, self.reso_plane)
# scatter plane features from points # scatter plane features from points
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2) fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
c = c.permute(0, 2, 1) # B x 512 x T c = c.permute(0, 2, 1) # B x 512 x T
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x 512 x reso x reso) fea_plane = fea_plane.reshape(
p.size(0), self.c_dim, self.reso_plane, self.reso_plane,
) # sparse matrix (B x 512 x reso x reso)
# process the plane features with UNet # process the plane features with UNet
if self.unet is not None: if self.unet is not None:
@@ -96,12 +116,14 @@ class LocalPoolPointnet(nn.Module):
def generate_grid_features(self, p, c): def generate_grid_features(self, p, c):
p_nor = normalize_3d_coordinate(p.clone()) p_nor = normalize_3d_coordinate(p.clone())
index = coordinate2index(p_nor, self.reso_grid, coord_type='3d') index = coordinate2index(p_nor, self.reso_grid, coord_type="3d")
# scatter grid features from points # scatter grid features from points
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3) fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
c = c.permute(0, 2, 1) c = c.permute(0, 2, 1)
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3 fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparse matrix (B x 512 x reso x reso x reso) fea_grid = fea_grid.reshape(
p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid,
) # sparse matrix (B x 512 x reso x reso x reso)
if self.unet3d is not None: if self.unet3d is not None:
fea_grid = self.unet3d(fea_grid) fea_grid = self.unet3d(fea_grid)
@@ -115,7 +137,7 @@ class LocalPoolPointnet(nn.Module):
c_out = 0 c_out = 0
for key in keys: for key in keys:
# scatter plane features from points # scatter plane features from points
if key == 'grid': if key == "grid":
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3) fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
else: else:
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2) fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
@@ -126,33 +148,31 @@ class LocalPoolPointnet(nn.Module):
c_out += fea c_out += fea
return c_out.permute(0, 2, 1) return c_out.permute(0, 2, 1)
def forward(self, p, normalize=True): def forward(self, p, normalize=True):
batch_size, T, D = p.size() batch_size, T, D = p.size()
# acquire the index for each point # acquire the index for each point
coord = {} coord = {}
index = {} index = {}
if 'xz' in self.plane_type: if "xz" in self.plane_type:
coord['xz'] = normalize_coordinate(p.clone(), plane='xz') coord["xz"] = normalize_coordinate(p.clone(), plane="xz")
index['xz'] = coordinate2index(coord['xz'], self.reso_plane) index["xz"] = coordinate2index(coord["xz"], self.reso_plane)
if 'xy' in self.plane_type: if "xy" in self.plane_type:
coord['xy'] = normalize_coordinate(p.clone(), plane='xy') coord["xy"] = normalize_coordinate(p.clone(), plane="xy")
index['xy'] = coordinate2index(coord['xy'], self.reso_plane) index["xy"] = coordinate2index(coord["xy"], self.reso_plane)
if 'yz' in self.plane_type: if "yz" in self.plane_type:
coord['yz'] = normalize_coordinate(p.clone(), plane='yz') coord["yz"] = normalize_coordinate(p.clone(), plane="yz")
index['yz'] = coordinate2index(coord['yz'], self.reso_plane) index["yz"] = coordinate2index(coord["yz"], self.reso_plane)
if 'grid' in self.plane_type: if "grid" in self.plane_type:
if normalize: if normalize:
coord['grid'] = normalize_3d_coordinate(p.clone()) coord["grid"] = normalize_3d_coordinate(p.clone())
else: else:
coord['grid'] = p.clone()[...,:3] coord["grid"] = p.clone()[..., :3]
index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d') index["grid"] = coordinate2index(coord["grid"], self.reso_grid, coord_type="3d")
if self.pe: if self.pe:
p = self.pe(p) p = self.pe(p)
if self.map2local: if self.map2local:
pp = self.map2local(p) pp = self.map2local(p)
net = self.fc_pos(pp) net = self.fc_pos(pp)
@@ -169,13 +189,13 @@ class LocalPoolPointnet(nn.Module):
c = self.fc_c(net) c = self.fc_c(net)
fea = {} fea = {}
if 'grid' in self.plane_type: if "grid" in self.plane_type:
fea['grid'] = self.generate_grid_features(p, c) fea["grid"] = self.generate_grid_features(p, c)
if 'xz' in self.plane_type: if "xz" in self.plane_type:
fea['xz'] = self.generate_plane_features(p, c, plane='xz') fea["xz"] = self.generate_plane_features(p, c, plane="xz")
if 'xy' in self.plane_type: if "xy" in self.plane_type:
fea['xy'] = self.generate_plane_features(p, c, plane='xy') fea["xy"] = self.generate_plane_features(p, c, plane="xy")
if 'yz' in self.plane_type: if "yz" in self.plane_type:
fea['yz'] = self.generate_plane_features(p, c, plane='yz') fea["yz"] = self.generate_plane_features(p, c, plane="yz")
return fea return fea
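
The core pooling trick in generate_plane_features — average point features into grid cells via a flattened cell index, then reshape into an image-like feature plane — in isolation. A standalone sketch using torch_scatter; the index math mimics what coordinate2index is assumed to do:

import torch
from torch_scatter import scatter_mean

B, T, C, reso = 2, 1024, 8, 16
xy = torch.rand(B, T, 2)            # plane coordinates already normalized to [0, 1)
c = torch.randn(B, T, C)            # per-point features

ij = (xy * reso).long().clamp(max=reso - 1)
index = (ij[..., 0] + reso * ij[..., 1]).unsqueeze(1)   # (B, 1, T) flattened cell ids

fea_plane = c.new_zeros(B, C, reso**2)
fea_plane = scatter_mean(c.permute(0, 2, 1), index, out=fea_plane)
fea_plane = fea_plane.reshape(B, C, reso, reso)          # ready for the 2D UNet
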


@@ -1,39 +1,41 @@
# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py) # code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py)
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from src.network.utils import get_embedder from src.network.utils import get_embedder
from pdb import set_trace as st
class RenderingNetwork(nn.Module): class RenderingNetwork(nn.Module):
def __init__( def __init__(
self, self,
fea_size=0, fea_size=0,
mode='naive', mode="naive",
d_out=3, d_out=3,
dims=[512, 512, 512, 512], dims=[512, 512, 512, 512],
weight_norm=True, weight_norm=True,
pe_freq_view=0 # for positional encoding pe_freq_view=0, # for positional encoding
): ):
super().__init__() super().__init__()
self.mode = mode self.mode = mode
if mode == 'naive': if mode == "naive":
d_in = 3 d_in = 3
elif mode == 'no_feature': elif mode == "no_feature":
d_in = 3 + 3 + 3 d_in = 3 + 3 + 3
fea_size = 0 fea_size = 0
elif mode == 'full': elif mode == "full":
d_in = 3 + 3 + 3 d_in = 3 + 3 + 3
else: else:
d_in = 3 + 3 d_in = 3 + 3
dims = [d_in + fea_size] + dims + [d_out] dims = [d_in + fea_size, *dims, d_out]
self.embedview_fn = None self.embedview_fn = None
if pe_freq_view > 0: if pe_freq_view > 0:
embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3) embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3)
self.embedview_fn = embedview_fn self.embedview_fn = embedview_fn
dims[0] += (input_ch - 3) dims[0] += input_ch - 3
self.num_layers = len(dims) self.num_layers = len(dims)
@@ -54,13 +56,13 @@ class RenderingNetwork(nn.Module):
view_dirs = self.embedview_fn(view_dirs) view_dirs = self.embedview_fn(view_dirs)
# points = self.embedview_fn(points) # points = self.embedview_fn(points)
if (self.mode == 'full') & (feature_vectors is not None): if (self.mode == "full") & (feature_vectors is not None):
rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
elif (self.mode == 'no_feature') | ((self.mode == 'full') & (feature_vectors is None)): elif (self.mode == "no_feature") | ((self.mode == "full") & (feature_vectors is None)):
rendering_input = torch.cat([points, view_dirs, normals], dim=-1) rendering_input = torch.cat([points, view_dirs, normals], dim=-1)
elif self.mode == 'no_view_dir': elif self.mode == "no_view_dir":
rendering_input = torch.cat([points, normals], dim=-1) rendering_input = torch.cat([points, normals], dim=-1)
elif self.mode == 'no_normal': elif self.mode == "no_normal":
rendering_input = torch.cat([points, view_dirs], dim=-1) rendering_input = torch.cat([points, view_dirs], dim=-1)
else: else:
rendering_input = points rendering_input = points
@@ -81,28 +83,27 @@ class RenderingNetwork(nn.Module):
class NeRFRenderingNetwork(nn.Module): class NeRFRenderingNetwork(nn.Module):
def __init__( def __init__(
self, self,
feature_vector_size=0, feature_vector_size=0,
mode='naive', mode="naive",
d_in=3, d_in=3,
d_out=3, d_out=3,
dims=[512, 512, 512, 256], dims=[512, 512, 512, 256],
weight_norm=True, weight_norm=True,
multires=0, # positional encoding of points multires=0, # positional encoding of points
multires_view=0 # positional encoding of view multires_view=0, # positional encoding of view
): ):
super().__init__() super().__init__()
self.mode = mode self.mode = mode
dims = [d_in + feature_vector_size] + dims dims = [d_in + feature_vector_size, *dims]
self.embed_fn = None self.embed_fn = None
if multires > 0: if multires > 0:
embed_fn, input_ch = get_embedder(multires, d_in=d_in) embed_fn, input_ch = get_embedder(multires, d_in=d_in)
self.embed_fn = embed_fn self.embed_fn = embed_fn
dims[0] += (input_ch - 3) dims[0] += input_ch - 3
self.num_layers = len(dims) self.num_layers = len(dims)
self.pts_net = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.num_layers - 1)]) self.pts_net = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.num_layers - 1)])
@@ -113,13 +114,12 @@ class NeRFRenderingNetwork(nn.Module):
self.embedview_fn = embedview_fn self.embedview_fn = embedview_fn
# dims[0] += (input_ch - 3) # dims[0] += (input_ch - 3)
if mode == 'full': if mode == "full":
self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)]) self.view_net = nn.ModuleList([nn.Linear(dims[-1] + view_ch, 128)])
self.rgb_net = nn.Linear(128, 3) self.rgb_net = nn.Linear(128, 3)
else: else:
self.rgb_net = nn.Linear(dims[-1], 3) self.rgb_net = nn.Linear(dims[-1], 3)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.tanh = nn.Tanh() self.tanh = nn.Tanh()
@@ -134,7 +134,7 @@ class NeRFRenderingNetwork(nn.Module):
x = net(x) x = net(x)
x = self.relu(x) x = self.relu(x)
if self.mode=='full': if self.mode == "full":
x = torch.cat([x, view_dirs], -1) x = torch.cat([x, view_dirs], -1)
for net in self.view_net: for net in self.view_net:
x = net(x) x = net(x)
@@ -144,22 +144,23 @@ class NeRFRenderingNetwork(nn.Module):
x = self.tanh(x) x = self.tanh(x)
return x return x
class ImplicitNetwork(nn.Module): class ImplicitNetwork(nn.Module):
def __init__( def __init__(
self, self,
d_in, d_in,
d_out, d_out,
dims, dims,
geometric_init=True, geometric_init=True,
feature_vector_size=0, feature_vector_size=0,
bias=1.0, bias=1.0,
skip_in=(), skip_in=(),
weight_norm=True, weight_norm=True,
multires=0 multires=0,
): ):
super().__init__() super().__init__()
dims = [d_in] + dims + [d_out + feature_vector_size] dims = [d_in, *dims, d_out + feature_vector_size]
self.embed_fn = None self.embed_fn = None
if multires > 0: if multires > 0:
@@ -189,7 +190,7 @@ class ImplicitNetwork(nn.Module):
elif multires > 0 and l in self.skip_in: elif multires > 0 and l in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
else: else:
torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
@@ -222,13 +223,9 @@ class ImplicitNetwork(nn.Module):
def gradient(self, x): def gradient(self, x):
x.requires_grad_(True) x.requires_grad_(True)
y = self.forward(x)[:,:1] y = self.forward(x)[:, :1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device) d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad( gradients = torch.autograd.grad(
outputs=y, outputs=y, inputs=x, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True,
inputs=x, )[0]
grad_outputs=d_output, return gradients.unsqueeze(1)
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
return gradients.unsqueeze(1)
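
The gradient method above is the standard autograd idiom for differentiating a scalar field with respect to its input coordinates; create_graph=True keeps the gradient itself differentiable so it can enter a loss. Reduced to a self-contained sketch with a toy field:

import torch

def field(x):  # stand-in for ImplicitNetwork.forward
    return (x ** 2).sum(dim=-1, keepdim=True)

x = torch.rand(8, 3).requires_grad_(True)
y = field(x)[:, :1]
d_output = torch.ones_like(y)
(grad,) = torch.autograd.grad(
    outputs=y, inputs=x, grad_outputs=d_output,
    create_graph=True, retain_graph=True, only_inputs=True,
)
assert torch.allclose(grad, 2 * x)  # analytic gradient of sum(x_i^2) is 2x
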


@@ -1,55 +1,38 @@
''' """Codes are from:
Codes are from: https://github.com/jaxony/unet-pytorch/blob/master/model.py.
https://github.com/jaxony/unet-pytorch/blob/master/model.py """
'''
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import init from torch.nn import init
import numpy as np
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def upconv2x2(in_channels, out_channels, mode='transpose'): def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1):
if mode == 'transpose': return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups)
return nn.ConvTranspose2d(
in_channels,
out_channels, def upconv2x2(in_channels, out_channels, mode="transpose"):
kernel_size=2, if mode == "transpose":
stride=2) return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
else: else:
# out_channels is always going to be the same # out_channels is always going to be the same
# as in_channels # as in_channels
return nn.Sequential( return nn.Sequential(nn.Upsample(mode="bilinear", scale_factor=2), conv1x1(in_channels, out_channels))
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1): def conv1x1(in_channels, out_channels, groups=1):
return nn.Conv2d( return nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, stride=1)
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
class DownConv(nn.Module): class DownConv(nn.Module):
""" """A helper Module that performs 2 convolutions and 1 MaxPool.
A helper Module that performs 2 convolutions and 1 MaxPool.
A ReLU activation follows each convolution. A ReLU activation follows each convolution.
""" """
def __init__(self, in_channels, out_channels, pooling=True): def __init__(self, in_channels, out_channels, pooling=True):
super(DownConv, self).__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@@ -71,39 +54,35 @@ class DownConv(nn.Module):
class UpConv(nn.Module): class UpConv(nn.Module):
""" """A helper Module that performs 2 convolutions and 1 UpConvolution.
A helper Module that performs 2 convolutions and 1 UpConvolution.
A ReLU activation follows each convolution. A ReLU activation follows each convolution.
""" """
def __init__(self, in_channels, out_channels,
merge_mode='concat', up_mode='transpose'): def __init__(self, in_channels, out_channels, merge_mode="concat", up_mode="transpose"):
super(UpConv, self).__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.merge_mode = merge_mode self.merge_mode = merge_mode
self.up_mode = up_mode self.up_mode = up_mode
self.upconv = upconv2x2(self.in_channels, self.out_channels, self.upconv = upconv2x2(self.in_channels, self.out_channels, mode=self.up_mode)
mode=self.up_mode)
if self.merge_mode == 'concat': if self.merge_mode == "concat":
self.conv1 = conv3x3( self.conv1 = conv3x3(2 * self.out_channels, self.out_channels)
2*self.out_channels, self.out_channels)
else: else:
# num of input channels to conv2 is same # num of input channels to conv2 is same
self.conv1 = conv3x3(self.out_channels, self.out_channels) self.conv1 = conv3x3(self.out_channels, self.out_channels)
self.conv2 = conv3x3(self.out_channels, self.out_channels) self.conv2 = conv3x3(self.out_channels, self.out_channels)
def forward(self, from_down, from_up): def forward(self, from_down, from_up):
""" Forward pass """Forward pass
Arguments: Arguments:
from_down: tensor from the encoder pathway from_down: tensor from the encoder pathway
from_up: upconv'd tensor from the decoder pathway from_up: upconv'd tensor from the decoder pathway.
""" """
from_up = self.upconv(from_up) from_up = self.upconv(from_up)
if self.merge_mode == 'concat': if self.merge_mode == "concat":
x = torch.cat((from_up, from_down), 1) x = torch.cat((from_up, from_down), 1)
else: else:
x = from_up + from_down x = from_up + from_down
@@ -113,7 +92,7 @@ class UpConv(nn.Module):
class UNet(nn.Module): class UNet(nn.Module):
""" `UNet` class is based on https://arxiv.org/abs/1505.04597 """`UNet` class is based on https://arxiv.org/abs/1505.04597.
The U-Net is a convolutional encoder-decoder neural network. The U-Net is a convolutional encoder-decoder neural network.
Contextual spatial information (from the decoding, Contextual spatial information (from the decoding,
@@ -135,44 +114,37 @@ class UNet(nn.Module):
the tranpose convolution (specified by upmode='transpose') the tranpose convolution (specified by upmode='transpose')
""" """
def __init__(self, num_classes, in_channels=3, depth=5, def __init__(
start_filts=64, up_mode='transpose', self, num_classes, in_channels=3, depth=5, start_filts=64, up_mode="transpose", merge_mode="concat", **kwargs,
merge_mode='concat', **kwargs): ):
"""Arguments:
in_channels: int, number of channels in the input tensor.
Default is 3 for RGB images.
depth: int, number of MaxPools in the U-Net.
start_filts: int, number of convolutional filters for the
first conv.
up_mode: string, type of upconvolution. Choices: 'transpose'
for transpose convolution or 'upsample' for nearest neighbour
upsampling.
""" """
Arguments: super().__init__()
in_channels: int, number of channels in the input tensor.
Default is 3 for RGB images.
depth: int, number of MaxPools in the U-Net.
start_filts: int, number of convolutional filters for the
first conv.
up_mode: string, type of upconvolution. Choices: 'transpose'
for transpose convolution or 'upsample' for nearest neighbour
upsampling.
"""
super(UNet, self).__init__()
if up_mode in ('transpose', 'upsample'): if up_mode in ("transpose", "upsample"):
self.up_mode = up_mode self.up_mode = up_mode
else: else:
raise ValueError("\"{}\" is not a valid mode for " msg = f'"{up_mode}" is not a valid mode for upsampling. Only "transpose" and "upsample" are allowed.'
"upsampling. Only \"transpose\" and " raise ValueError(msg)
"\"upsample\" are allowed.".format(up_mode))
if merge_mode in ("concat", "add"):
if merge_mode in ('concat', 'add'):
self.merge_mode = merge_mode self.merge_mode = merge_mode
else: else:
raise ValueError("\"{}\" is not a valid mode for" msg = f'"{up_mode}" is not a valid mode formerging up and down paths. Only "concat" and "add" are allowed.'
"merging up and down paths. " raise ValueError(msg)
"Only \"concat\" and "
"\"add\" are allowed.".format(up_mode))
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
if self.up_mode == 'upsample' and self.merge_mode == 'add': if self.up_mode == "upsample" and self.merge_mode == "add":
raise ValueError("up_mode \"upsample\" is incompatible " msg = 'up_mode "upsample" is incompatible with merge_mode "add" at the moment because it doesn\'t make sense to use nearest neighbour to reduce depth channels (by half).'
"with merge_mode \"add\" at the moment " raise ValueError(msg)
"because it doesn't make sense to use "
"nearest neighbour to reduce "
"depth channels (by half).")
self.num_classes = num_classes self.num_classes = num_classes
self.in_channels = in_channels self.in_channels = in_channels
@@ -185,19 +157,18 @@ class UNet(nn.Module):
# create the encoder pathway and add to a list # create the encoder pathway and add to a list
for i in range(depth): for i in range(depth):
ins = self.in_channels if i == 0 else outs ins = self.in_channels if i == 0 else outs
outs = self.start_filts*(2**i) outs = self.start_filts * (2**i)
pooling = True if i < depth-1 else False pooling = True if i < depth - 1 else False
down_conv = DownConv(ins, outs, pooling=pooling) down_conv = DownConv(ins, outs, pooling=pooling)
self.down_convs.append(down_conv) self.down_convs.append(down_conv)
# create the decoder pathway and add to a list # create the decoder pathway and add to a list
# - careful! decoding only requires depth-1 blocks # - careful! decoding only requires depth-1 blocks
for i in range(depth-1): for i in range(depth - 1):
ins = outs ins = outs
outs = ins // 2 outs = ins // 2
up_conv = UpConv(ins, outs, up_mode=up_mode, up_conv = UpConv(ins, outs, up_mode=up_mode, merge_mode=merge_mode)
merge_mode=merge_mode)
self.up_convs.append(up_conv) self.up_convs.append(up_conv)
# add the list of modules to current module # add the list of modules to current module
@@ -214,12 +185,10 @@ class UNet(nn.Module):
init.xavier_normal_(m.weight) init.xavier_normal_(m.weight)
init.constant_(m.bias, 0) init.constant_(m.bias, 0)
def reset_params(self): def reset_params(self):
for i, m in enumerate(self.modules()): for _i, m in enumerate(self.modules()):
self.weight_init(m) self.weight_init(m)
def forward(self, x): def forward(self, x):
encoder_outs = [] encoder_outs = []
# encoder pathway, save outputs for merging # encoder pathway, save outputs for merging
@@ -227,30 +196,31 @@ class UNet(nn.Module):
x, before_pool = module(x) x, before_pool = module(x)
encoder_outs.append(before_pool) encoder_outs.append(before_pool)
for i, module in enumerate(self.up_convs): for i, module in enumerate(self.up_convs):
before_pool = encoder_outs[-(i+2)] before_pool = encoder_outs[-(i + 2)]
x = module(before_pool, x) x = module(before_pool, x)
# No softmax is used. This means you need to use # No softmax is used. This means you need to use
# nn.CrossEntropyLoss in your training script, # nn.CrossEntropyLoss in your training script,
# as this module includes a softmax already. # as this module includes a softmax already.
x = self.conv_final(x) x = self.conv_final(x)
return x return x
if __name__ == "__main__": if __name__ == "__main__":
""" """
testing testing
""" """
model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32) model = UNet(1, depth=5, merge_mode="concat", in_channels=1, start_filts=32)
print(model) print(model)
print(sum(p.numel() for p in model.parameters())) print(sum(p.numel() for p in model.parameters()))
reso = 176 reso = 176
x = np.zeros((1, 1, reso, reso)) x = np.zeros((1, 1, reso, reso))
x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
x = torch.FloatTensor(x) x = torch.FloatTensor(x)
out = model(x) out = model(x)
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso))) print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso)))
# loss = torch.sum(out) # loss = torch.sum(out)
# loss.backward() # loss.backward()
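
The __main__ block above uses NaN poisoning to probe the receptive field: a single NaN planted at the center contaminates every output pixel that depends on it, so the printed fraction measures one input pixel's footprint. For context, a sketch of how LocalPoolPointnet's unet option would instantiate this class on plane features (the kwargs are illustrative, not taken from a shipped config):

import torch
from src.network.unet import UNet

unet = UNet(num_classes=32, in_channels=32, depth=4, start_filts=32, merge_mode="concat")
fea_plane = torch.randn(1, 32, 64, 64)   # (B, c_dim, reso, reso) plane features
out = unet(fea_plane)                    # spatial size preserved, channels -> num_classes
assert out.shape == (1, 32, 64, 64)
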


@@ -1,16 +1,18 @@
''' """Code from the 3D UNet implementation:
Code from the 3D UNet implementation: https://github.com/wolny/pytorch-3dunet/.
https://github.com/wolny/pytorch-3dunet/ """
'''
import importlib import importlib
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from functools import partial
from src.network.utils import get_embedder from src.network.utils import get_embedder
def number_of_features_per_level(init_channel_number, num_levels): def number_of_features_per_level(init_channel_number, num_levels):
return [init_channel_number * 2 ** k for k in range(num_levels)] return [init_channel_number * 2**k for k in range(num_levels)]
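
For example, number_of_features_per_level doubles the channel width at every level:

assert [64 * 2**k for k in range(4)] == [64, 128, 256, 512]
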
def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
@@ -18,8 +20,7 @@ def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
""" """Create a list of modules with together constitute a single conv layer with non-linearity
Create a list of modules with together constitute a single conv layer with non-linearity
and optional batchnorm/groupnorm. and optional batchnorm/groupnorm.
Args: Args:
@@ -37,23 +38,23 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
Return: Return:
list of tuple (name, module) list of tuple (name, module)
""" """
assert 'c' in order, "Conv layer MUST be present" assert "c" in order, "Conv layer MUST be present"
assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' assert order[0] not in "rle", "Non-linearity cannot be the first operation in the layer"
modules = [] modules = []
for i, char in enumerate(order): for i, char in enumerate(order):
if char == 'r': if char == "r":
modules.append(('ReLU', nn.ReLU(inplace=True))) modules.append(("ReLU", nn.ReLU(inplace=True)))
elif char == 'l': elif char == "l":
modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) modules.append(("LeakyReLU", nn.LeakyReLU(negative_slope=0.1, inplace=True)))
elif char == 'e': elif char == "e":
modules.append(('ELU', nn.ELU(inplace=True))) modules.append(("ELU", nn.ELU(inplace=True)))
elif char == 'c': elif char == "c":
# add learnable bias only in the absence of batchnorm/groupnorm # add learnable bias only in the absence of batchnorm/groupnorm
bias = not ('g' in order or 'b' in order) bias = not ("g" in order or "b" in order)
modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) modules.append(("conv", conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
elif char == 'g': elif char == "g":
is_before_conv = i < order.index('c') is_before_conv = i < order.index("c")
if is_before_conv: if is_before_conv:
num_channels = in_channels num_channels = in_channels
else: else:
@@ -63,14 +64,16 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
if num_channels < num_groups: if num_channels < num_groups:
num_groups = 1 num_groups = 1
assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' assert (
modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) num_channels % num_groups == 0
elif char == 'b': ), f"Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}"
is_before_conv = i < order.index('c') modules.append(("groupnorm", nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
elif char == "b":
is_before_conv = i < order.index("c")
if is_before_conv: if is_before_conv:
modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) modules.append(("batchnorm", nn.BatchNorm3d(in_channels)))
else: else:
modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) modules.append(("batchnorm", nn.BatchNorm3d(out_channels)))
else: else:
raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
@@ -78,9 +81,8 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
class SingleConv(nn.Sequential): class SingleConv(nn.Sequential):
""" """Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order of operations can be specified via the `order` parameter.
of operations can be specified via the `order` parameter
Args: Args:
in_channels (int): number of input channels in_channels (int): number of input channels
@@ -94,16 +96,15 @@ class SingleConv(nn.Sequential):
num_groups (int): number of groups for the GroupNorm num_groups (int): number of groups for the GroupNorm
""" """
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1): def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8, padding=1):
super(SingleConv, self).__init__() super().__init__()
for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
self.add_module(name, module) self.add_module(name, module)
class DoubleConv(nn.Sequential): class DoubleConv(nn.Sequential):
""" """A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
We use (Conv3d+ReLU+GroupNorm3d) by default. We use (Conv3d+ReLU+GroupNorm3d) by default.
This can be changed however by providing the 'order' argument, e.g. in order This can be changed however by providing the 'order' argument, e.g. in order
to change to Conv3d+BatchNorm3d+ELU use order='cbe'. to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
@@ -123,8 +124,8 @@ class DoubleConv(nn.Sequential):
num_groups (int): number of groups for the GroupNorm num_groups (int): number of groups for the GroupNorm
""" """
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8): def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order="crg", num_groups=8):
super(DoubleConv, self).__init__() super().__init__()
if encoder: if encoder:
# we're in the encoder path # we're in the encoder path
conv1_in_channels = in_channels conv1_in_channels = in_channels
@@ -138,26 +139,27 @@ class DoubleConv(nn.Sequential):
conv2_in_channels, conv2_out_channels = out_channels, out_channels conv2_in_channels, conv2_out_channels = out_channels, out_channels
# conv1 # conv1
self.add_module('SingleConv1', self.add_module(
SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) "SingleConv1", SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups),
)
# conv2 # conv2
self.add_module('SingleConv2', self.add_module(
SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) "SingleConv2", SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups),
)
class ExtResNetBlock(nn.Module): class ExtResNetBlock(nn.Module):
""" """Basic UNet block consisting of a SingleConv followed by the residual block.
Basic UNet block consisting of a SingleConv followed by the residual block.
The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
of output channels is compatible with the residual block that follows. of output channels is compatible with the residual block that follows.
This block can be used instead of standard DoubleConv in the Encoder module. This block can be used instead of standard DoubleConv in the Encoder module.
Motivated by: https://arxiv.org/pdf/1706.00120.pdf Motivated by: https://arxiv.org/pdf/1706.00120.pdf.
Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
""" """
def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): def __init__(self, in_channels, out_channels, kernel_size=3, order="cge", num_groups=8, **kwargs):
super(ExtResNetBlock, self).__init__() super().__init__()
# first convolution # first convolution
self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
@@ -165,15 +167,16 @@ class ExtResNetBlock(nn.Module):
self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
# remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
n_order = order n_order = order
for c in 'rel': for c in "rel":
n_order = n_order.replace(c, '') n_order = n_order.replace(c, "")
self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, self.conv3 = SingleConv(
num_groups=num_groups) out_channels, out_channels, kernel_size=kernel_size, order=n_order, num_groups=num_groups,
)
# create non-linearity separately # create non-linearity separately
if 'l' in order: if "l" in order:
self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
elif 'e' in order: elif "e" in order:
self.non_linearity = nn.ELU(inplace=True) self.non_linearity = nn.ELU(inplace=True)
else: else:
self.non_linearity = nn.ReLU(inplace=True) self.non_linearity = nn.ReLU(inplace=True)
@@ -194,12 +197,12 @@ class ExtResNetBlock(nn.Module):
class Encoder(nn.Module): class Encoder(nn.Module):
""" """A single module from the encoder path consisting of the optional max
A single module from the encoder path consisting of the optional max
pooling layer (one may specify the MaxPool kernel_size to be different pooling layer (one may specify the MaxPool kernel_size to be different
than the standard (2,2,2), e.g. if the volumetric data is anisotropic than the standard (2,2,2), e.g. if the volumetric data is anisotropic
(make sure to use complementary scale_factor in the decoder path) followed by (make sure to use complementary scale_factor in the decoder path) followed by
a DoubleConv module. a DoubleConv module.
Args: Args:
in_channels (int): number of input channels in_channels (int): number of input channels
out_channels (int): number of output channels out_channels (int): number of output channels
@@ -210,27 +213,39 @@ class Encoder(nn.Module):
basic_module(nn.Module): either ResNetBlock or DoubleConv basic_module(nn.Module): either ResNetBlock or DoubleConv
conv_layer_order (string): determines the order of layers conv_layer_order (string): determines the order of layers
in `DoubleConv` module. See `DoubleConv` for more info. in `DoubleConv` module. See `DoubleConv` for more info.
num_groups (int): number of groups for the GroupNorm num_groups (int): number of groups for the GroupNorm.
""" """
def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, def __init__(
pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg', self,
num_groups=8): in_channels,
super(Encoder, self).__init__() out_channels,
assert pool_type in ['max', 'avg'] conv_kernel_size=3,
apply_pooling=True,
pool_kernel_size=(2, 2, 2),
pool_type="max",
basic_module=DoubleConv,
conv_layer_order="crg",
num_groups=8,
):
super().__init__()
assert pool_type in ["max", "avg"]
if apply_pooling: if apply_pooling:
if pool_type == 'max': if pool_type == "max":
self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
else: else:
self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
else: else:
self.pooling = None self.pooling = None
self.basic_module = basic_module(in_channels, out_channels, self.basic_module = basic_module(
encoder=True, in_channels,
kernel_size=conv_kernel_size, out_channels,
order=conv_layer_order, encoder=True,
num_groups=num_groups) kernel_size=conv_kernel_size,
order=conv_layer_order,
num_groups=num_groups,
)
def forward(self, x): def forward(self, x):
if self.pooling is not None: if self.pooling is not None:
@@ -240,9 +255,9 @@ class Encoder(nn.Module):
class Decoder(nn.Module): class Decoder(nn.Module):
""" """A single module for decoder path consisting of the upsampling layer
A single module for decoder path consisting of the upsampling layer
(either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock). (either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
Args: Args:
in_channels (int): number of input channels in_channels (int): number of input channels
out_channels (int): number of output channels out_channels (int): number of output channels
@@ -253,32 +268,56 @@ class Decoder(nn.Module):
basic_module(nn.Module): either ResNetBlock or DoubleConv basic_module(nn.Module): either ResNetBlock or DoubleConv
conv_layer_order (string): determines the order of layers conv_layer_order (string): determines the order of layers
in `DoubleConv` module. See `DoubleConv` for more info. in `DoubleConv` module. See `DoubleConv` for more info.
num_groups (int): number of groups for the GroupNorm num_groups (int): number of groups for the GroupNorm.
""" """
def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv, def __init__(
conv_layer_order='crg', num_groups=8, mode='nearest'): self,
super(Decoder, self).__init__() in_channels,
out_channels,
kernel_size=3,
scale_factor=(2, 2, 2),
basic_module=DoubleConv,
conv_layer_order="crg",
num_groups=8,
mode="nearest",
):
super().__init__()
if basic_module == DoubleConv: if basic_module == DoubleConv:
# if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels, self.upsampling = Upsampling(
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) transposed_conv=False,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
scale_factor=scale_factor,
mode=mode,
)
# concat joining # concat joining
self.joining = partial(self._joining, concat=True) self.joining = partial(self._joining, concat=True)
else: else:
# if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels, self.upsampling = Upsampling(
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) transposed_conv=True,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
scale_factor=scale_factor,
mode=mode,
)
# sum joining # sum joining
self.joining = partial(self._joining, concat=False) self.joining = partial(self._joining, concat=False)
# adapt the number of in_channels for the ExtResNetBlock # adapt the number of in_channels for the ExtResNetBlock
in_channels = out_channels in_channels = out_channels
self.basic_module = basic_module(in_channels, out_channels, self.basic_module = basic_module(
encoder=False, in_channels,
kernel_size=kernel_size, out_channels,
order=conv_layer_order, encoder=False,
num_groups=num_groups) kernel_size=kernel_size,
order=conv_layer_order,
num_groups=num_groups,
)
def forward(self, encoder_features, x): def forward(self, encoder_features, x):
x = self.upsampling(encoder_features=encoder_features, x=x) x = self.upsampling(encoder_features=encoder_features, x=x)
@@ -295,8 +334,7 @@ class Decoder(nn.Module):
class Upsampling(nn.Module): class Upsampling(nn.Module):
""" """Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
Args: Args:
transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
@@ -310,15 +348,23 @@ class Upsampling(nn.Module):
'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest' 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
""" """
def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3, def __init__(
scale_factor=(2, 2, 2), mode='nearest'): self,
super(Upsampling, self).__init__() transposed_conv,
in_channels=None,
out_channels=None,
kernel_size=3,
scale_factor=(2, 2, 2),
mode="nearest",
):
super().__init__()
if transposed_conv: if transposed_conv:
# make sure that the output size reverses the MaxPool3d from the corresponding encoder # make sure that the output size reverses the MaxPool3d from the corresponding encoder
# (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0]) # (D_out = (D_in - 1) x stride[0] - 2 x padding[0] + kernel_size[0] + output_padding[0])
self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, self.upsample = nn.ConvTranspose3d(
padding=1) in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, padding=1,
)
else: else:
self.upsample = partial(self._interpolate, mode=mode) self.upsample = partial(self._interpolate, mode=mode)
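A quick standalone check of the output-size formula above (a sketch that assumes only PyTorch; the channel counts are arbitrary):

import torch
import torch.nn as nn

# D_out = (D_in - 1) * stride - 2 * padding + kernel_size + output_padding
# with kernel_size=3, stride=2, padding=1, output_padding=0 and D_in=9:
# (9 - 1) * 2 - 2 * 1 + 3 + 0 = 17
up = nn.ConvTranspose3d(8, 4, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 8, 9, 9, 9)
print(up(x).shape)  # torch.Size([1, 4, 17, 17, 17])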
@ -332,13 +378,13 @@ class Upsampling(nn.Module):
class FinalConv(nn.Sequential): class FinalConv(nn.Sequential):
""" """A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
which reduces the number of channels to 'out_channels'. which reduces the number of channels to 'out_channels'.
The two convolutions use 'in_channels' and 'out_channels' output channels, respectively. The two convolutions use 'in_channels' and 'out_channels' output channels, respectively.
We use (Conv3d+ReLU+GroupNorm3d) by default. We use (Conv3d+ReLU+GroupNorm3d) by default.
This can be changed, however, by providing the 'order' argument, e.g. in order This can be changed, however, by providing the 'order' argument, e.g. in order
to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
Args: Args:
in_channels (int): number of input channels in_channels (int): number of input channels
out_channels (int): number of output channels out_channels (int): number of output channels
@ -346,22 +392,22 @@ class FinalConv(nn.Sequential):
order (string): determines the order of layers, e.g. order (string): determines the order of layers, e.g.
'cr' -> conv + ReLU 'cr' -> conv + ReLU
'crg' -> conv + ReLU + groupnorm 'crg' -> conv + ReLU + groupnorm
num_groups (int): number of groups for the GroupNorm num_groups (int): number of groups for the GroupNorm.
""" """
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8): def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8):
super(FinalConv, self).__init__() super().__init__()
# conv1 # conv1
self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) self.add_module("SingleConv", SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
# in the last layer a 1×1 convolution reduces the number of output channels to out_channels # in the last layer a 1x1 convolution reduces the number of output channels to out_channels
final_conv = nn.Conv3d(in_channels, out_channels, 1) final_conv = nn.Conv3d(in_channels, out_channels, 1)
self.add_module('final_conv', final_conv) self.add_module("final_conv", final_conv)
class Abstract3DUNet(nn.Module): class Abstract3DUNet(nn.Module):
""" """Base class for standard and residual UNet.
Base class for standard and residual UNet.
Args: Args:
in_channels (int): number of input channels in_channels (int): number of input channels
@ -391,9 +437,22 @@ class Abstract3DUNet(nn.Module):
and the `final_activation` (even if present) won't be applied; default: False and the `final_activation` (even if present) won't be applied; default: False
""" """
def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', def __init__(
num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs): self,
super(Abstract3DUNet, self).__init__() in_channels,
out_channels,
final_sigmoid,
basic_module,
f_maps=64,
layer_order="gcr",
num_groups=8,
num_levels=4,
is_segmentation=False,
testing=False,
pe_freq=0,
**kwargs,
):
super().__init__()
self.testing = testing self.testing = testing
@ -411,13 +470,24 @@ class Abstract3DUNet(nn.Module):
encoders = [] encoders = []
for i, out_feature_num in enumerate(f_maps): for i, out_feature_num in enumerate(f_maps):
if i == 0: if i == 0:
encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module, encoder = Encoder(
conv_layer_order=layer_order, num_groups=num_groups) in_channels,
out_feature_num,
apply_pooling=False,
basic_module=basic_module,
conv_layer_order=layer_order,
num_groups=num_groups,
)
else: else:
# TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
# currently pools with a constant kernel: (2, 2, 2) # currently pools with a constant kernel: (2, 2, 2)
encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module, encoder = Encoder(
conv_layer_order=layer_order, num_groups=num_groups) f_maps[i - 1],
out_feature_num,
basic_module=basic_module,
conv_layer_order=layer_order,
num_groups=num_groups,
)
encoders.append(encoder) encoders.append(encoder)
self.encoders = nn.ModuleList(encoders) self.encoders = nn.ModuleList(encoders)
@ -434,13 +504,18 @@ class Abstract3DUNet(nn.Module):
out_feature_num = reversed_f_maps[i + 1] out_feature_num = reversed_f_maps[i + 1]
# TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
# currently strides with a constant stride: (2, 2, 2) # currently strides with a constant stride: (2, 2, 2)
decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module, decoder = Decoder(
conv_layer_order=layer_order, num_groups=num_groups) in_feature_num,
out_feature_num,
basic_module=basic_module,
conv_layer_order=layer_order,
num_groups=num_groups,
)
decoders.append(decoder) decoders.append(decoder)
self.decoders = nn.ModuleList(decoders) self.decoders = nn.ModuleList(decoders)
# in the last layer a 1×1 convolution reduces the number of output # in the last layer a 1x1 convolution reduces the number of output
# channels to the number of labels # channels to the number of labels
self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
@ -455,7 +530,6 @@ class Abstract3DUNet(nn.Module):
self.final_activation = None self.final_activation = None
def forward(self, x): def forward(self, x):
if self.embed_fn is not None: if self.embed_fn is not None:
x = self.embed_fn(x.permute(0, 2, 3, 4, 1)) x = self.embed_fn(x.permute(0, 2, 3, 4, 1))
x = x.permute(0, 4, 1, 2, 3) x = x.permute(0, 4, 1, 2, 3)
@ -488,49 +562,81 @@ class Abstract3DUNet(nn.Module):
class UNet3D(Abstract3DUNet): class UNet3D(Abstract3DUNet):
""" """3DUnet model from
3DUnet model from
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
<https://arxiv.org/pdf/1606.06650.pdf>`. <https://arxiv.org/pdf/1606.06650.pdf>`.
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
""" """
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', def __init__(
num_groups=8, num_levels=4, is_segmentation=True, **kwargs): self,
super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid, in_channels,
basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order, out_channels,
num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation, final_sigmoid=True,
**kwargs) f_maps=64,
layer_order="gcr",
num_groups=8,
num_levels=4,
is_segmentation=True,
**kwargs,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
basic_module=DoubleConv,
f_maps=f_maps,
layer_order=layer_order,
num_groups=num_groups,
num_levels=num_levels,
is_segmentation=is_segmentation,
**kwargs,
)
class ResidualUNet3D(Abstract3DUNet): class ResidualUNet3D(Abstract3DUNet):
""" """Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
Uses ExtResNetBlock as a basic building block, summation joining instead Uses ExtResNetBlock as a basic building block, summation joining instead
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts). of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
Since the model effectively becomes a residual net, it in theory allows for a deeper UNet. Since the model effectively becomes a residual net, it in theory allows for a deeper UNet.
""" """
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', def __init__(
num_groups=8, num_levels=5, is_segmentation=True, **kwargs): self,
super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, in_channels,
final_sigmoid=final_sigmoid, out_channels,
basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order, final_sigmoid=True,
num_groups=num_groups, num_levels=num_levels, f_maps=64,
is_segmentation=is_segmentation, layer_order="gcr",
**kwargs) num_groups=8,
num_levels=5,
is_segmentation=True,
**kwargs,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
basic_module=ExtResNetBlock,
f_maps=f_maps,
layer_order=layer_order,
num_groups=num_groups,
num_levels=num_levels,
is_segmentation=is_segmentation,
**kwargs,
)
def get_model(config): def get_model(config):
def _model_class(class_name): def _model_class(class_name):
m = importlib.import_module('pytorch3dunet.unet3d.model') m = importlib.import_module("pytorch3dunet.unet3d.model")
clazz = getattr(m, class_name) clazz = getattr(m, class_name)
return clazz return clazz
assert 'model' in config, 'Could not find model configuration' assert "model" in config, "Could not find model configuration"
model_config = config['model'] model_config = config["model"]
model_class = _model_class(model_config['name']) model_class = _model_class(model_config["name"])
return model_class(**model_config) return model_class(**model_config)
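A usage sketch for get_model with hypothetical config values (it assumes the pytorch3dunet package imported above is installed; note that the whole "model" section, including "name", is forwarded as keyword arguments, which the models absorb via **kwargs):

config = {
    "model": {
        "name": "UNet3D",  # resolved by _model_class via importlib
        "in_channels": 1,
        "out_channels": 1,
        "f_maps": 32,
        "num_levels": 2,
    },
}
model = get_model(config)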
@ -542,18 +648,18 @@ if __name__ == "__main__":
out_channels = 1 out_channels = 1
f_maps = 32 f_maps = 32
num_levels = 2 num_levels = 2
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr') model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order="cr")
print(model) print(model)
print('number of parameters: ', sum(p.numel() for p in model.parameters())) print("number of parameters: ", sum(p.numel() for p in model.parameters()))
reso = 18 reso = 18
import numpy as np import numpy as np
import torch import torch
x = np.zeros((1, 1, reso, reso, reso)) x = np.zeros((1, 1, reso, reso, reso))
x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan x[:, :, int(reso / 2 - 1), int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
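# a single NaN voxel poisons every convolution output whose receptive field
# touches it, so the NaN fraction printed below estimates receptive-field size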
x = torch.FloatTensor(x) x = torch.FloatTensor(x)
out = model(x) out = model(x)
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso))) print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso * reso)))
View file
@ -1,8 +1,9 @@
""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ """Positional encoding embedding. Code was taken from https://github.com/bmild/nerf."""
import torch import torch
import torch.nn as nn import torch.nn as nn
class Embedder: class Embedder:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.kwargs = kwargs self.kwargs = kwargs
@ -10,24 +11,23 @@ class Embedder:
def create_embedding_fn(self): def create_embedding_fn(self):
embed_fns = [] embed_fns = []
d = self.kwargs['input_dims'] d = self.kwargs["input_dims"]
out_dim = 0 out_dim = 0
if self.kwargs['include_input']: if self.kwargs["include_input"]:
embed_fns.append(lambda x: x) embed_fns.append(lambda x: x)
out_dim += d out_dim += d
max_freq = self.kwargs['max_freq_log2'] max_freq = self.kwargs["max_freq_log2"]
N_freqs = self.kwargs['num_freqs'] N_freqs = self.kwargs["num_freqs"]
if self.kwargs['log_sampling']: if self.kwargs["log_sampling"]:
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs)
else: else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs)
for freq in freq_bands: for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']: for p_fn in self.kwargs["periodic_fns"]:
embed_fns.append(lambda x, p_fn=p_fn, embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
freq=freq: p_fn(x * freq))
out_dim += d out_dim += d
self.embed_fns = embed_fns self.embed_fns = embed_fns
@ -36,35 +36,40 @@ class Embedder:
def embed(self, inputs): def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1) return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, d_in=3): def get_embedder(multires, d_in=3):
embed_kwargs = { embed_kwargs = {
'include_input': True, "include_input": True,
'input_dims': d_in, "input_dims": d_in,
'max_freq_log2': multires-1, "max_freq_log2": multires - 1,
'num_freqs': multires, "num_freqs": multires,
'log_sampling': True, "log_sampling": True,
'periodic_fns': [torch.sin, torch.cos], "periodic_fns": [torch.sin, torch.cos],
} }
embedder_obj = Embedder(**embed_kwargs) embedder_obj = Embedder(**embed_kwargs)
def embed(x, eo=embedder_obj): return eo.embed(x)
def embed(x, eo=embedder_obj):
return eo.embed(x)
return embed, embedder_obj.out_dim return embed, embedder_obj.out_dim
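A small usage sketch, given the definitions above: with include_input=True, each of the d_in dimensions contributes one identity channel plus two channels (sin and cos) per frequency, so out_dim = d_in * (1 + 2 * multires); multires=10, d_in=3 is the common NeRF setting:

import torch

embed, out_dim = get_embedder(multires=10, d_in=3)
print(out_dim)         # 3 * (1 + 2 * 10) = 63
x = torch.rand(1024, 3)
print(embed(x).shape)  # torch.Size([1024, 63])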
def normalize_coordinate(p, plane='xz'):
''' Normalize coordinate to [0, 1] for unit cube experiments def normalize_coordinate(p, plane="xz"):
"""Normalize coordinate to [0, 1] for unit cube experiments.
Args: Args:
p (tensor): point p (tensor): point
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
plane (str): plane feature type, ['xz', 'xy', 'yz'] plane (str): plane feature type, ['xz', 'xy', 'yz']
''' """
if plane == 'xz': if plane == "xz":
xy = p[:, :, [0, 2]] xy = p[:, :, [0, 2]]
elif plane =='xy': elif plane == "xy":
xy = p[:, :, [0, 1]] xy = p[:, :, [0, 1]]
else: else:
xy = p[:, :, [1, 2]] xy = p[:, :, [1, 2]]
xy_new = xy xy_new = xy
# if there are outliers out of the range # if there are outliers out of the range
if xy_new.max() >= 1: if xy_new.max() >= 1:
@ -75,40 +80,41 @@ def normalize_coordinate(p, plane='xz'):
def normalize_3d_coordinate(p): def normalize_3d_coordinate(p):
''' Normalize coordinate to [0, 1] for unit cube experiments. """Normalize coordinate to [0, 1] for unit cube experiments."""
'''
if p.max() >= 1: if p.max() >= 1:
p[p >= 1] = 1 - 10e-6 p[p >= 1] = 1 - 10e-6
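# note: 10e-6 equals 1e-5, so coordinates are clamped to at most 0.99999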
if p.min() < 0: if p.min() < 0:
p[p < 0] = 0.0 p[p < 0] = 0.0
return p return p
def coordinate2index(x, reso, coord_type='2d'):
''' Convert normalized coordinates in [0, 1] to flattened grid indices. def coordinate2index(x, reso, coord_type="2d"):
Corresponds to our 3D model """Convert normalized coordinates in [0, 1] to flattened grid indices.
Corresponds to our 3D model.
Args: Args:
x (tensor): coordinate x (tensor): coordinate
reso (int): defined resolution reso (int): defined resolution
coord_type (str): coordinate type coord_type (str): coordinate type
''' """
x = (x * reso).long() x = (x * reso).long()
if coord_type == '2d': # plane if coord_type == "2d": # plane
index = x[:, :, 0] + reso * x[:, :, 1] index = x[:, :, 0] + reso * x[:, :, 1]
elif coord_type == '3d': # grid elif coord_type == "3d": # grid
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2]) index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
index = index[:, None, :] index = index[:, None, :]
return index return index
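A tiny 2D illustration with arbitrary values: floor(x * reso) gives integer cell coordinates, which are then flattened row-major as x0 + reso * x1:

import torch

x = torch.tensor([[[0.0, 0.0], [0.3, 0.5], [0.9, 0.9]]])  # (batch, n_points, 2)
print(coordinate2index(x, reso=4, coord_type="2d"))
# cells (0, 0), (1, 2), (3, 3) -> tensor([[[ 0,  9, 15]]])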
class map2local(object): class map2local:
''' Map global coordinates into the local coordinate system of each voxel. """Map global coordinates into the local coordinate system of each voxel.
Args: Args:
s (float): the defined voxel size s (float): the defined voxel size
pos_encoding (str): method for the positional encoding, linear|sin_cos pos_encoding (str): method for the positional encoding, linear|sin_cos
''' """
def __init__(self, s, pos_encoding='linear'):
def __init__(self, s, pos_encoding="linear"):
super().__init__() super().__init__()
self.s = s self.s = s
# self.pe = positional_encoding(basis_function=pos_encoding, local=True) # self.pe = positional_encoding(basis_function=pos_encoding, local=True)
@ -121,15 +127,16 @@ class map2local(object):
# p = self.pe(p) # p = self.pe(p)
return p return p
# Resnet Blocks # Resnet Blocks
class ResnetBlockFC(nn.Module): class ResnetBlockFC(nn.Module):
''' Fully connected ResNet Block class. """Fully connected ResNet Block class.
Args: Args:
size_in (int): input dimension size_in (int): input dimension
size_out (int): output dimension size_out (int): output dimension
size_h (int): hidden dimension size_h (int): hidden dimension
''' """
def __init__(self, size_in, size_out=None, size_h=None, siren=False): def __init__(self, size_in, size_out=None, size_h=None, siren=False):
super().__init__() super().__init__()
@ -164,4 +171,4 @@ class ResnetBlockFC(nn.Module):
else: else:
x_s = x x_s = x
return x_s + dx return x_s + dx
View file
@ -1,112 +1,110 @@
import time, os import os
import numpy as np import numpy as np
import open3d as o3d
import torch import torch
from torch.nn import functional as F
import trimesh import trimesh
from pytorch3d.loss import chamfer_distance
from torchvision.io import write_video
from torchvision.utils import save_image
from src.dpsr import DPSR from src.dpsr import DPSR
from src.model import PSR2Mesh from src.model import PSR2Mesh
from src.utils import grid_interp, verts_on_largest_mesh,\ from src.utils import export_pointcloud, mc_from_psr, verts_on_largest_mesh
export_pointcloud, mc_from_psr, GaussianSmoothing from src.visualize import (
from src.visualize import visualize_points_mesh, visualize_psr_grid, \ render_rgb,
visualize_mesh_phong, render_rgb visualize_mesh_phong,
from torchvision.utils import save_image visualize_points_mesh,
from torchvision.io import write_video visualize_psr_grid,
from pytorch3d.loss import chamfer_distance )
import open3d as o3d
class Trainer(object):
''' class Trainer:
Args: """Args:
cfg : config file cfg : config file
optimizer : pytorch optimizer object optimizer : pytorch optimizer object
device : pytorch device device : pytorch device.
''' """
def __init__(self, cfg, optimizer, device=None): def __init__(self, cfg, optimizer, device=None):
self.optimizer = optimizer self.optimizer = optimizer
self.device = device self.device = device
self.cfg = cfg self.cfg = cfg
self.psr2mesh = PSR2Mesh.apply self.psr2mesh = PSR2Mesh.apply
self.data_type = cfg['data']['data_type'] self.data_type = cfg["data"]["data_type"]
# initialize DPSR # initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'], self.dpsr = DPSR(
cfg['model']['grid_res'], res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
cfg['model']['grid_res']), sig=cfg["model"]["psr_sigma"],
sig=cfg['model']['psr_sigma']) )
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR
self.dpsr = self.dpsr.to(device) self.dpsr = self.dpsr.to(device)
def train_step(self, data, inputs, model, it): def train_step(self, data, inputs, model, it):
''' Performs a training step. """Performs a training step.
Args: Args:
data (dict) : data dictionary data (dict) : data dictionary
inputs (torch.tensor) : input point clouds inputs (torch.tensor) : input point clouds
model (nn.Module or None): a neural network or None model (nn.Module or None): a neural network or None
it (int) : the number of iterations it (int) : the number of iterations
''' """
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss, loss_each = self.compute_loss(inputs, data, model, it) loss, loss_each = self.compute_loss(inputs, data, model, it)
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
return loss.item(), loss_each return loss.item(), loss_each
def compute_loss(self, inputs, data, model, it=0): def compute_loss(self, inputs, data, model, it=0):
''' Compute the loss. """Compute the loss.
Args: Args:
data (dict) : data dictionary data (dict) : data dictionary
inputs (torch.tensor) : input point clouds inputs (torch.tensor) : input point clouds
model (nn.Module or None): a neural network or None model (nn.Module or None): a neural network or None
it (int) : the number of iterations it (int) : the number of iterations.
''' """
res = self.cfg["model"]["grid_res"]
device = self.device
res = self.cfg['model']['grid_res']
# source oriented point clouds to PSR grid # source oriented point clouds to PSR grid
psr_grid, points, normals = self.pcl2psr(inputs) psr_grid, points, normals = self.pcl2psr(inputs)
# build mesh # build mesh
v, f, n = self.psr2mesh(psr_grid) v, f, n = self.psr2mesh(psr_grid)
# the output is in the range of [0, 1); we rescale it to the real range [0, 1].
# This is a hack for our DPSR solver
v = v * res / (res-1)
points = points * 2. - 1. # the output is in the range of [0, 1); we rescale it to the real range [0, 1].
v = v * 2. - 1. # within the range of (-1, 1) # This is a hack for our DPSR solver
v = v * res / (res - 1)
points = points * 2.0 - 1.0
v = v * 2.0 - 1.0 # within the range of (-1, 1)
loss = 0 loss = 0
loss_each = {} loss_each = {}
# compute loss # compute loss
if self.data_type == 'point': if self.data_type == "point":
if self.cfg['train']['w_chamfer'] > 0: if self.cfg["train"]["w_chamfer"] > 0:
loss_ = self.cfg['train']['w_chamfer'] * \ loss_ = self.cfg["train"]["w_chamfer"] * self.compute_3d_loss(v, data)
self.compute_3d_loss(v, data) loss_each["chamfer"] = loss_
loss_each['chamfer'] = loss_
loss += loss_ loss += loss_
elif self.data_type == 'img': elif self.data_type == "img":
loss, loss_each = self.compute_2d_loss(inputs, data, model) loss, loss_each = self.compute_2d_loss(inputs, data, model)
return loss, loss_each return loss, loss_each
def pcl2psr(self, inputs): def pcl2psr(self, inputs):
''' Convert an oriented point cloud to PSR indicator grid """Convert an oriented point cloud to PSR indicator grid
Args: Args:
inputs (torch.tensor): input oriented point clouds inputs (torch.tensor): input oriented point clouds.
''' """
points, normals = inputs[..., :3], inputs[..., 3:]
points, normals = inputs[...,:3], inputs[...,3:] if self.cfg["model"]["apply_sigmoid"]:
if self.cfg['model']['apply_sigmoid']:
points = torch.sigmoid(points) points = torch.sigmoid(points)
if self.cfg['model']['normal_normalize']: if self.cfg["model"]["normal_normalize"]:
normals = normals / normals.norm(dim=-1, keepdim=True) normals = normals / normals.norm(dim=-1, keepdim=True)
# DPSR to get grid # DPSR to get grid
@ -116,53 +114,50 @@ class Trainer(object):
return psr_grid, points, normals return psr_grid, points, normals
def compute_3d_loss(self, v, data): def compute_3d_loss(self, v, data):
''' Compute the loss for point clouds. """Compute the loss for point clouds.
Args: Args:
v (torch.tensor) : mesh vertices v (torch.tensor) : mesh vertices
data (dict) : data dictionary data (dict) : data dictionary.
''' """
pts_gt = data.get("target_points")
pts_gt = data.get('target_points') idx = np.random.randint(pts_gt.shape[1], size=self.cfg["train"]["n_sup_point"])
idx = np.random.randint(pts_gt.shape[1], size=self.cfg['train']['n_sup_point']) if self.cfg["train"]["subsample_vertex"]:
if self.cfg['train']['subsample_vertex']: # chamfer distance only on random sampled vertices
#chamfer distance only on random sampled vertices idx = np.random.randint(v.shape[1], size=self.cfg["train"]["n_sup_point"])
idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point'])
loss, _ = chamfer_distance(v[:, idx], pts_gt) loss, _ = chamfer_distance(v[:, idx], pts_gt)
else: else:
loss, _ = chamfer_distance(v, pts_gt) loss, _ = chamfer_distance(v, pts_gt)
return loss return loss
def compute_2d_loss(self, inputs, data, model): def compute_2d_loss(self, inputs, data, model):
''' Compute the 2D losses. """Compute the 2D losses.
Args: Args:
inputs (torch.tensor) : input source point clouds inputs (torch.tensor) : input source point clouds
data (dict) : data dictionary data (dict) : data dictionary
model (nn.Module or None): neural network or None model (nn.Module or None): neural network or None.
''' """
losses = {
losses = {"color": "color": {
{"weight": self.cfg['train']['l_weight']['rgb'], "weight": self.cfg["train"]["l_weight"]["rgb"],
"values": [] "values": [],
}, },
"silhouette": "silhouette": {"weight": self.cfg["train"]["l_weight"]["mask"], "values": []},
{"weight": self.cfg['train']['l_weight']['mask'], }
"values": []},
}
loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses} loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses}
# forward pass # forward pass
out = model(inputs, data) out = model(inputs, data)
if out['rgb'] is not None: if out["rgb"] is not None:
rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'], rgb_gt = out["rgb_gt"].reshape(self.cfg["data"]["n_views_per_iter"], -1, 3)[out["vis_mask"]]
-1, 3)[out['vis_mask']] loss_all["color"] += torch.nn.L1Loss(reduction="sum")(rgb_gt, out["rgb"]) / out["rgb"].shape[0]
loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt,
out['rgb']) / out['rgb'].shape[0] if out["mask"] is not None:
loss_all["silhouette"] += ((out["mask"] - out["mask_gt"]) ** 2).mean()
if out['mask'] is not None:
loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean()
# weighted sum of the losses # weighted sum of the losses
loss = torch.tensor(0.0, device=self.device) loss = torch.tensor(0.0, device=self.device)
for k, l in loss_all.items(): for k, l in loss_all.items():
@ -172,154 +167,146 @@ class Trainer(object):
return loss, loss_all return loss, loss_all
def point_resampling(self, inputs): def point_resampling(self, inputs):
''' Resample points """Resample points
Args: Args:
inputs (torch.tensor): oriented point clouds inputs (torch.tensor): oriented point clouds.
''' """
psr_grid, points, normals = self.pcl2psr(inputs) psr_grid, points, normals = self.pcl2psr(inputs)
# shortcuts
n_grow = self.cfg['train']['n_grow_points']
# [hack] for points resampled from the mesh from marching cubes, # shortcuts
n_grow = self.cfg["train"]["n_grow_points"]
# [hack] for points resampled from the mesh from marching cubes,
# we need to divide by s instead of (s-1), and the scale is correct. # we need to divide by s instead of (s-1), and the scale is correct.
verts, faces, _ = mc_from_psr(psr_grid, real_scale=False, zero_level=0) verts, faces, _ = mc_from_psr(psr_grid, real_scale=False, zero_level=0)
# find the largest component # find the largest component
pts_mesh, faces_mesh = verts_on_largest_mesh(verts, faces) pts_mesh, faces_mesh = verts_on_largest_mesh(verts, faces)
# sample vertices only from the largest component, not from fragments # sample vertices only from the largest component, not from fragments
mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh) mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh)
pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True) pi, face_idx = mesh.sample(n_grow + points.shape[1], return_index=True)
normals_i = mesh.face_normals[face_idx].astype('float32') normals_i = mesh.face_normals[face_idx].astype("float32")
pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None] pts_mesh = torch.tensor(pi.astype("float32")).to(self.device)[None]
n_mesh = torch.tensor(normals_i).to(self.device)[None] n_mesh = torch.tensor(normals_i).to(self.device)[None]
points, normals = pts_mesh, n_mesh points, normals = pts_mesh, n_mesh
print('{} total points are resampled'.format(points.shape[1])) print(f"{points.shape[1]} total points are resampled")
# update inputs # update inputs
points = torch.log(points / (1 - points)) # inverse sigmoid points = torch.log(points / (1 - points)) # inverse sigmoid
inputs = torch.cat([points, normals], dim=-1) inputs = torch.cat([points, normals], dim=-1)
inputs.requires_grad = True inputs.requires_grad = True
return inputs return inputs
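The logit above is the exact inverse of the sigmoid that pcl2psr applies, so the optimizer can move points in an unconstrained space while they remain inside the unit cube; a standalone round-trip check:

import torch

p = torch.tensor([0.25, 0.5, 0.9])
logits = torch.log(p / (1 - p))  # inverse sigmoid, as above
print(torch.sigmoid(logits))     # tensor([0.2500, 0.5000, 0.9000])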
def visualize(self, data, inputs, renderer, epoch, o3d_vis=None): def visualize(self, data, inputs, renderer, epoch, o3d_vis=None):
''' Visualization. """Visualization.
Args: Args:
data (dict) : data dictionary data (dict) : data dictionary
inputs (torch.tensor) : source point clouds inputs (torch.tensor) : source point clouds
renderer (nn.Module or None): a neural network or None renderer (nn.Module or None): a neural network or None
epoch (int) : the number of iterations epoch (int) : the number of iterations
o3d_vis (o3d.Visualizer) : open3d visualizer o3d_vis (o3d.Visualizer) : open3d visualizer.
''' """
data_type = self.cfg["data"]["data_type"]
data_type = self.cfg['data']['data_type'] it = "{:04d}".format(int(epoch / self.cfg["train"]["visualize_every"]))
it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every']))
if (self.cfg["train"]["exp_mesh"]) | (self.cfg["train"]["exp_pcl"]) | (self.cfg["train"]["o3d_show"]):
if (self.cfg['train']['exp_mesh']) \
| (self.cfg['train']['exp_pcl']) \
| (self.cfg['train']['o3d_show']):
psr_grid, points, normals = self.pcl2psr(inputs) psr_grid, points, normals = self.pcl2psr(inputs)
with torch.no_grad(): with torch.no_grad():
v, f, n = mc_from_psr(psr_grid, pytorchify=True, v, f, n = mc_from_psr(
zero_level=self.cfg['data']['zero_level'], real_scale=True) psr_grid, pytorchify=True, zero_level=self.cfg["data"]["zero_level"], real_scale=True,
)
v, f, n = v[None], f[None], n[None] v, f, n = v[None], f[None], n[None]
v = v * 2. - 1. # change to the range of [-1, 1] v = v * 2.0 - 1.0 # change to the range of [-1, 1]
color_v = None color_v = None
if data_type == 'img': if data_type == "img":
if self.cfg['train']['vis_vert_color'] & \ if self.cfg["train"]["vis_vert_color"] & (self.cfg["train"]["l_weight"]["rgb"] != 0.0):
(self.cfg['train']['l_weight']['rgb'] != 0.): color_v = renderer["color"](v, n).squeeze().detach().cpu().numpy()
color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy() color_v[color_v < 0], color_v[color_v > 1] = 0.0, 1.0
color_v[color_v<0], color_v[color_v>1] = 0., 1.
vv = v.detach().squeeze().cpu().numpy() vv = v.detach().squeeze().cpu().numpy()
ff = f.detach().squeeze().cpu().numpy() ff = f.detach().squeeze().cpu().numpy()
points = points * 2 - 1 points = points * 2 - 1
visualize_points_mesh(o3d_vis, points, normals, visualize_points_mesh(o3d_vis, points, normals, vv, ff, self.cfg, it, epoch, color_v=color_v)
vv, ff, self.cfg, it, epoch, color_v=color_v)
else: else:
v, f, n = inputs v, f, n = inputs
if (data_type == "img") & (self.cfg["train"]["vis_rendering"]):
if (data_type == 'img') & (self.cfg['train']['vis_rendering']):
pred_imgs = [] pred_imgs = []
pred_masks = [] len(data["poses"])
n_views = len(data['poses'])
# idx_list = trange(n_views) # idx_list = trange(n_views)
idx_list = [13, 24, 27, 48] idx_list = [13, 24, 27, 48]
#! #!
model = renderer.eval() model = renderer.eval()
for idx in idx_list: for idx in idx_list:
pose = data['poses'][idx] pose = data["poses"][idx]
rgb = data['rgbs'][idx] rgb = data["rgbs"][idx]
mask_gt = data['masks'][idx] data["masks"][idx]
img_size = rgb.shape[0] if rgb.shape[0]== rgb.shape[1] else (rgb.shape[0], rgb.shape[1]) img_size = rgb.shape[0] if rgb.shape[0] == rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
ray = None ray = None
if 'rays' in data.keys(): if "rays" in data.keys():
ray = data['rays'][idx] ray = data["rays"][idx]
if self.cfg['train']['l_weight']['rgb'] != 0.: if self.cfg["train"]["l_weight"]["rgb"] != 0.0:
fea_grid = None fea_grid = None
if model.unet3d is not None: if model.unet3d is not None:
with torch.no_grad(): with torch.no_grad():
fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1) fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1)
if model.encoder is not None: if model.encoder is not None:
pp = torch.cat([(points+1)/2, normals], dim=-1) pp = torch.cat([(points + 1) / 2, normals], dim=-1)
fea_grid = model.encoder(pp, fea_grid = model.encoder(pp, normalize=False).permute(0, 2, 3, 4, 1)
normalize=False).permute(0, 2, 3, 4, 1)
pred, visible_mask = render_rgb(v, f, n, pose, pred, visible_mask = render_rgb(
model.rendering_network.eval(), v, f, n, pose, model.rendering_network.eval(), img_size, ray=ray, fea_grid=fea_grid,
img_size, ray=ray, fea_grid=fea_grid) )
img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3]) img_pred = torch.zeros([rgb.shape[0] * rgb.shape[1], 3])
img_pred[visible_mask] = pred.detach().cpu() img_pred[visible_mask] = pred.detach().cpu()
img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3) img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3)
img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1. img_pred[img_pred < 0], img_pred[img_pred > 1] = 0.0, 1.0
filename=os.path.join(self.cfg['train']['dir_rendering'], filename = os.path.join(self.cfg["train"]["dir_rendering"], f"rendering_{it}_{idx:d}.png")
'rendering_{}_{:d}.png'.format(it, idx))
save_image(img_pred.permute(2, 0, 1), filename) save_image(img_pred.permute(2, 0, 1), filename)
pred_imgs.append(img_pred) pred_imgs.append(img_pred)
#! Mesh rendering using Phong shading model #! Mesh rendering using Phong shading model
filename=os.path.join(self.cfg['train']['dir_rendering'], filename = os.path.join(self.cfg["train"]["dir_rendering"], f"mesh_{it}_{idx:d}.png")
'mesh_{}_{:d}.png'.format(it, idx))
visualize_mesh_phong(v, f, n, pose, img_size, name=filename) visualize_mesh_phong(v, f, n, pose, img_size, name=filename)
if len(pred_imgs) >= 1: if len(pred_imgs) >= 1:
pred_imgs = torch.stack(pred_imgs, dim=0) pred_imgs = torch.stack(pred_imgs, dim=0)
save_image(pred_imgs.permute(0, 3, 1, 2), save_image(
os.path.join(self.cfg['train']['dir_rendering'], pred_imgs.permute(0, 3, 1, 2), os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.png"), nrow=4,
'{}.png'.format(it)), nrow=4) )
if self.cfg['train']['save_video']: if self.cfg["train"]["save_video"]:
write_video(os.path.join(self.cfg['train']['dir_rendering'], write_video(
'{}.mp4'.format(it)), os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.mp4"),
(pred_imgs*255.).type(torch.uint8), fps=24) (pred_imgs * 255.0).type(torch.uint8),
fps=24,
)
def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None): def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None):
''' Save meshes and point clouds. """Save meshes and point clouds.
Args: Args:
inputs (torch.tensor) : source point clouds inputs (torch.tensor) : source point clouds
epoch (int) : the number of iterations epoch (int) : the number of iterations
center (numpy.array) : center of the shape center (numpy.array) : center of the shape
scale (numpy.array) : scale of the shape scale (numpy.array) : scale of the shape.
''' """
exp_pcl = self.cfg["train"]["exp_pcl"]
exp_mesh = self.cfg["train"]["exp_mesh"]
exp_pcl = self.cfg['train']['exp_pcl']
exp_mesh = self.cfg['train']['exp_mesh']
psr_grid, points, normals = self.pcl2psr(inputs) psr_grid, points, normals = self.pcl2psr(inputs)
if exp_pcl: if exp_pcl:
dir_pcl = self.cfg['train']['dir_pcl'] dir_pcl = self.cfg["train"]["dir_pcl"]
p = points.squeeze(0).detach().cpu().numpy() p = points.squeeze(0).detach().cpu().numpy()
p = p * 2 - 1 p = p * 2 - 1
n = normals.squeeze(0).detach().cpu().numpy() n = normals.squeeze(0).detach().cpu().numpy()
@ -327,12 +314,11 @@ class Trainer(object):
p *= scale p *= scale
if center is not None: if center is not None:
p += center p += center
export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n) export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}.ply"), p, n)
if exp_mesh: if exp_mesh:
dir_mesh = self.cfg['train']['dir_mesh'] dir_mesh = self.cfg["train"]["dir_mesh"]
with torch.no_grad(): with torch.no_grad():
v, f, _ = mc_from_psr(psr_grid, v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"], real_scale=True)
zero_level=self.cfg['data']['zero_level'], real_scale=True)
v = v * 2 - 1 v = v * 2 - 1
if scale is not None: if scale is not None:
v *= scale v *= scale
@ -341,9 +327,9 @@ class Trainer(object):
mesh = o3d.geometry.TriangleMesh() mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(v) mesh.vertices = o3d.utility.Vector3dVector(v)
mesh.triangles = o3d.utility.Vector3iVector(f) mesh.triangles = o3d.utility.Vector3iVector(f)
outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch)) outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}.ply")
o3d.io.write_triangle_mesh(outdir_mesh, mesh) o3d.io.write_triangle_mesh(outdir_mesh, mesh)
if self.cfg['train']['vis_psr']: if self.cfg["train"]["vis_psr"]:
dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all' dir_psr_vis = self.cfg["train"]["out_dir"] + "/psr_vis_all"
visualize_psr_grid(psr_grid, out_dir=dir_psr_vis) visualize_psr_grid(psr_grid, out_dir=dir_psr_vis)
View file
@ -1,75 +1,81 @@
import os import os
from collections import defaultdict
import numpy as np import numpy as np
import torch import torch
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops.knn import knn_gather, knn_points
from torch.nn import functional as F from torch.nn import functional as F
from collections import defaultdict
import trimesh
from tqdm import tqdm from tqdm import tqdm
from src.dpsr import DPSR from src.dpsr import DPSR
from src.utils import grid_interp, export_pointcloud, export_mesh, \ from src.utils import (
mc_from_psr, scale2onet, GaussianSmoothing GaussianSmoothing,
from pytorch3d.ops.knn import knn_gather, knn_points export_mesh,
from pytorch3d.loss import chamfer_distance export_pointcloud,
from pdb import set_trace as st mc_from_psr,
scale2onet,
)
class Trainer:
"""Args:
model (nn.Module): our defined model
optimizer (optimizer): pytorch optimizer object
device (device): pytorch device
input_type (str): input type
vis_dir (str): visualization directory.
"""
class Trainer(object):
'''
Args:
model (nn.Module): our defined model
optimizer (optimizer): pytorch optimizer object
device (device): pytorch device
input_type (str): input type
vis_dir (str): visualization directory
'''
def __init__(self, cfg, optimizer, device=None): def __init__(self, cfg, optimizer, device=None):
self.optimizer = optimizer self.optimizer = optimizer
self.device = device self.device = device
self.cfg = cfg self.cfg = cfg
if self.cfg['train']['w_raw'] != 0: if self.cfg["train"]["w_raw"] != 0:
from src.model import PSR2Mesh from src.model import PSR2Mesh
self.psr2mesh = PSR2Mesh.apply self.psr2mesh = PSR2Mesh.apply
# initialize DPSR # initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'], self.dpsr = DPSR(
cfg['model']['grid_res'], res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
cfg['model']['grid_res']), sig=cfg["model"]["psr_sigma"],
sig=cfg['model']['psr_sigma']) )
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR
self.dpsr = self.dpsr.to(device) self.dpsr = self.dpsr.to(device)
if cfg['train']['gauss_weight']>0.: if cfg["train"]["gauss_weight"] > 0.0:
self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device) self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device)
def train_step(self, inputs, data, model): def train_step(self, inputs, data, model):
''' Performs a training step. """Performs a training step.
Args: Args:
data (dict): data dictionary data (dict): data dictionary
''' """
self.optimizer.zero_grad() self.optimizer.zero_grad()
p = data.get('inputs').to(self.device) p = data.get("inputs").to(self.device)
out = model(p) out = model(p)
points, normals = out points, normals = out
loss = 0 loss = 0
loss_each = {} loss_each = {}
if self.cfg['train']['w_psr'] != 0: if self.cfg["train"]["w_psr"] != 0:
psr_gt = data.get('gt_psr').to(self.device) psr_gt = data.get("gt_psr").to(self.device)
if self.cfg['model']['psr_tanh']: if self.cfg["model"]["psr_tanh"]:
psr_gt = torch.tanh(psr_gt) psr_gt = torch.tanh(psr_gt)
psr_grid = self.dpsr(points, normals) psr_grid = self.dpsr(points, normals)
if self.cfg['model']['psr_tanh']: if self.cfg["model"]["psr_tanh"]:
psr_grid = torch.tanh(psr_grid) psr_grid = torch.tanh(psr_grid)
# apply a rescaling weight based on GT SDF values # apply a rescaling weight based on GT SDF values
if self.cfg['train']['gauss_weight']>0: if self.cfg["train"]["gauss_weight"] > 0:
gauss_sigma = self.cfg['train']['gauss_weight'] self.cfg["train"]["gauss_weight"]
# set up the weighting for loss, higher weights # set up the weighting for loss, higher weights
# for points near to the surface # for points near to the surface
psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1) psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1)
delta_x = delta_y = delta_z = 1 delta_x = delta_y = delta_z = 1
@ -82,97 +88,97 @@ class Trainer(object):
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1) psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1)
psr_grad_norm = psr_grad.norm(dim=-1)[:, None] psr_grad_norm = psr_grad.norm(dim=-1)[:, None]
w = torch.nn.ReplicationPad3d(3)(psr_grad_norm) w = torch.nn.ReplicationPad3d(3)(psr_grad_norm)
w = 2*self.gauss_smooth(w).squeeze(1) w = 2 * self.gauss_smooth(w).squeeze(1)
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt) loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(w * psr_grid, w * psr_gt)
else: else:
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt) loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(psr_grid, psr_gt)
loss += loss_each['psr'] loss += loss_each["psr"]
# regularization on the input point positions via chamfer distance # regularization on the input point positions via chamfer distance
if self.cfg['train']['w_reg_point'] != 0.: if self.cfg["train"]["w_reg_point"] != 0.0:
points_gt = data.get('gt_points').to(self.device) points_gt = data.get("gt_points").to(self.device)
loss_reg, loss_norm = chamfer_distance(points, points_gt) loss_reg, loss_norm = chamfer_distance(points, points_gt)
loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg loss_each["reg"] = self.cfg["train"]["w_reg_point"] * loss_reg
loss += loss_each['reg'] loss += loss_each["reg"]
if self.cfg['train']['w_normals'] != 0.: if self.cfg["train"]["w_normals"] != 0.0:
points_gt = data.get('gt_points').to(self.device) points_gt = data.get("gt_points").to(self.device)
normals_gt = data.get('gt_points.normals').to(self.device) normals_gt = data.get("gt_points.normals").to(self.device)
x_nn = knn_points(points, points_gt, K=1) x_nn = knn_points(points, points_gt, K=1)
x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :] x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :]
cham_norm_x = F.l1_loss(normals, x_normals_near) cham_norm_x = F.l1_loss(normals, x_normals_near)
loss_norm = cham_norm_x loss_norm = cham_norm_x
loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm loss_each["normals"] = self.cfg["train"]["w_normals"] * loss_norm
loss += loss_each['normals'] loss += loss_each["normals"]
if self.cfg['train']['w_raw'] != 0: if self.cfg["train"]["w_raw"] != 0:
res = self.cfg['model']['grid_res'] self.cfg["model"]["grid_res"]
# DPSR to get grid # DPSR to get grid
psr_grid = self.dpsr(points, normals) psr_grid = self.dpsr(points, normals)
if self.cfg['model']['psr_tanh']: if self.cfg["model"]["psr_tanh"]:
psr_grid = torch.tanh(psr_grid) psr_grid = torch.tanh(psr_grid)
v, f, n = self.psr2mesh(psr_grid) v, f, n = self.psr2mesh(psr_grid)
pts_gt = data.get('gt_points').to(self.device) pts_gt = data.get("gt_points").to(self.device)
loss, _ = chamfer_distance(v, pts_gt) loss, _ = chamfer_distance(v, pts_gt)
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
return loss.item(), loss_each return loss.item(), loss_each
def save(self, model, data, epoch, id):
p = data.get('inputs').to(self.device)
exp_pcl = self.cfg['train']['exp_pcl'] def save(self, model, data, epoch, id):
exp_mesh = self.cfg['train']['exp_mesh'] p = data.get("inputs").to(self.device)
exp_gt = self.cfg['generation']['exp_gt']
exp_input = self.cfg['generation']['exp_input'] exp_pcl = self.cfg["train"]["exp_pcl"]
exp_mesh = self.cfg["train"]["exp_mesh"]
exp_gt = self.cfg["generation"]["exp_gt"]
exp_input = self.cfg["generation"]["exp_input"]
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
points, normals = model(p) points, normals = model(p)
if exp_gt: if exp_gt:
points_gt = data.get('gt_points').to(self.device) points_gt = data.get("gt_points").to(self.device)
normals_gt = data.get('gt_points.normals').to(self.device) normals_gt = data.get("gt_points.normals").to(self.device)
if exp_pcl: if exp_pcl:
dir_pcl = self.cfg['train']['dir_pcl'] dir_pcl = self.cfg["train"]["dir_pcl"]
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals) export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}.ply"), scale2onet(points), normals)
if exp_gt: if exp_gt:
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt) export_pointcloud(
os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(points_gt), normals_gt,
)
if exp_input: if exp_input:
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p)) export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_input.ply"), scale2onet(p))
if exp_mesh: if exp_mesh:
dir_mesh = self.cfg['train']['dir_mesh'] dir_mesh = self.cfg["train"]["dir_mesh"]
psr_grid = self.dpsr(points, normals) psr_grid = self.dpsr(points, normals)
# psr_grid = torch.tanh(psr_grid) # psr_grid = torch.tanh(psr_grid)
with torch.no_grad(): with torch.no_grad():
v, f, _ = mc_from_psr(psr_grid, v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"])
zero_level=self.cfg['data']['zero_level']) outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}.ply")
outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id))
export_mesh(outdir_mesh, scale2onet(v), f) export_mesh(outdir_mesh, scale2onet(v), f)
if exp_gt: if exp_gt:
psr_gt = self.dpsr(points_gt, normals_gt) psr_gt = self.dpsr(points_gt, normals_gt)
with torch.no_grad(): with torch.no_grad():
v, f, _ = mc_from_psr(psr_gt, v, f, _ = mc_from_psr(psr_gt, zero_level=self.cfg["data"]["zero_level"])
zero_level=self.cfg['data']['zero_level']) export_mesh(os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(v), f)
export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f)
def evaluate(self, val_loader, model): def evaluate(self, val_loader, model):
''' Performs an evaluation. """Performs an evaluation.
Args: Args:
val_loader (dataloader): pytorch dataloader val_loader (dataloader): pytorch dataloader.
''' """
eval_list = defaultdict(list) eval_list = defaultdict(list)
for data in tqdm(val_loader): for data in tqdm(val_loader):
@ -183,25 +189,26 @@ class Trainer(object):
eval_dict = {k: np.mean(v) for k, v in eval_list.items()} eval_dict = {k: np.mean(v) for k, v in eval_list.items()}
return eval_dict return eval_dict
def eval_step(self, data, model): def eval_step(self, data, model):
''' Performs an evaluation step. """Performs an evaluation step.
Args: Args:
data (dict): data dictionary data (dict): data dictionary.
''' """
model.eval() model.eval()
eval_dict = {} eval_dict = {}
p = data.get('inputs').to(self.device) p = data.get("inputs").to(self.device)
psr_gt = data.get('gt_psr').to(self.device) psr_gt = data.get("gt_psr").to(self.device)
with torch.no_grad(): with torch.no_grad():
# forward pass # forward pass
points, normals = model(p) points, normals = model(p)
# DPSR to get predicted psr grid # DPSR to get predicted psr grid
psr_grid = self.dpsr(points, normals) psr_grid = self.dpsr(points, normals)
eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item() eval_dict["psr_l1"] = F.l1_loss(psr_grid, psr_gt).item()
eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item() eval_dict["psr_l2"] = F.mse_loss(psr_grid, psr_gt).item()
return eval_dict return eval_dict
View file
@ -1,53 +1,53 @@
import torch import logging
import io, os, logging, urllib
import yaml
import trimesh
import imageio
import numbers
import math import math
import numpy as np import numbers
import os
import urllib
from collections import OrderedDict from collections import OrderedDict
import numpy as np
import open3d as o3d
import torch
import trimesh
import yaml
from igl import adjacency_matrix, connected_components
from plyfile import PlyData from plyfile import PlyData
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
from pytorch3d.structures import Meshes
from skimage import measure
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from torch.utils import model_zoo from torch.utils import model_zoo
from skimage import measure, img_as_float32
from pytorch3d.structures import Meshes
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
from igl import adjacency_matrix, connected_components
import open3d as o3d
################################################## ##################################################
# Below are functions for DPSR # Below are functions for DPSR
def fftfreqs(res, dtype=torch.float32, exact=True): def fftfreqs(res, dtype=torch.float32, exact=True):
""" """Helper function to return frequency tensors
Helper function to return frequency tensors
:param res: n_dims int tuple of number of frequency modes :param res: n_dims int tuple of number of frequency modes
:return: :return:
""" """
n_dims = len(res) n_dims = len(res)
freqs = [] freqs = []
for dim in range(n_dims - 1): for dim in range(n_dims - 1):
r_ = res[dim] r_ = res[dim]
freq = np.fft.fftfreq(r_, d=1/r_) freq = np.fft.fftfreq(r_, d=1 / r_)
freqs.append(torch.tensor(freq, dtype=dtype)) freqs.append(torch.tensor(freq, dtype=dtype))
r_ = res[-1] r_ = res[-1]
if exact: if exact:
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype)) freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_), dtype=dtype))
else: else:
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype)) freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_)[:-1], dtype=dtype))
omega = torch.meshgrid(freqs) omega = torch.meshgrid(freqs)
omega = list(omega) omega = list(omega)
omega = torch.stack(omega, dim=-1) omega = torch.stack(omega, dim=-1)
return omega return omega
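A shape sketch, given the definition above: the last axis uses rfftfreq and thus holds res // 2 + 1 bins when exact=True, and the trailing dimension stacks the per-axis frequencies:

omega = fftfreqs((8, 8, 8))
print(omega.shape)  # torch.Size([8, 8, 5, 3])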
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
""" def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
multiply tensor x by i ** deg """Multiply tensor x by i ** deg."""
"""
deg %= 4 deg %= 4
if deg == 0: if deg == 0:
res = x res = x
@ -61,17 +61,18 @@ def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
res[..., 1] = -res[..., 1] res[..., 1] = -res[..., 1]
return res return res
def spec_gaussian_filter(res, sig): def spec_gaussian_filter(res, sig):
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d] omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
dis = torch.sqrt(torch.sum(omega ** 2, dim=-1)) dis = torch.sqrt(torch.sum(omega**2, dim=-1))
filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1) filter_ = torch.exp(-0.5 * ((sig * 2 * dis / res[0]) ** 2)).unsqueeze(-1).unsqueeze(-1)
filter_.requires_grad = False filter_.requires_grad = False
return filter_ return filter_
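Correspondingly, a standalone sketch of the spectral Gaussian: the DC bin is exp(0) = 1.0, higher-frequency bins decay, and two trailing singleton dims are kept for broadcasting:

filt = spec_gaussian_filter(res=(8, 8, 8), sig=2)
print(filt.shape)     # torch.Size([8, 8, 5, 1, 1])
print(filt[0, 0, 0])  # tensor([[1.]], dtype=torch.float64)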
def grid_interp(grid, pts, batched=True): def grid_interp(grid, pts, batched=True):
""" """:param grid: tensor of shape (batch, *size, in_features)
:param grid: tensor of shape (batch, *size, in_features)
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1) :param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
:return values at query points :return values at query points
""" """
@ -82,69 +83,71 @@ def grid_interp(grid, pts, batched=True):
bs = grid.shape[0] bs = grid.shape[0]
size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype) size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
cubesize = 1.0 / size cubesize = 1.0 / size
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim) ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim) ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
tmp = torch.tensor([0,1],dtype=torch.long) tmp = torch.tensor([0, 1], dtype=torch.long)
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim) com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim) dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points) ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim) ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
# latent code on neighbor nodes # latent code on neighbor nodes
if dim == 2: if dim == 2:
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features) lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
else: else:
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features) lat = grid.clone()[
ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2],
] # (batch, num_points, 2**dim, in_features)
# weights of neighboring nodes # weights of neighboring nodes
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim) xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim) xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim) xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
pos_ = pos_.type(pts.dtype) pos_ = pos_.type(pts.dtype)
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim) dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim) weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features) query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features)
if not batched: if not batched:
query_values = query_values.squeeze(0) query_values = query_values.squeeze(0)
return query_values return query_values
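Because the eight corner weights of this trilinear scheme sum to one, interpolating a constant grid returns the constant; a quick check with arbitrary shapes:

import torch

grid = 3.0 * torch.ones(1, 4, 4, 4, 2)  # (batch, *size, in_features)
pts = 0.999 * torch.rand(1, 10, 3)      # query points in (0, 1)
print(grid_interp(grid, pts))           # every entry is 3.0 up to float error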


def scatter_to_grid(inds, vals, size):
    """Scatter-add values into an empty tensor of shape `size`.

    :param inds: (#values, dims)
    :param vals: (#values,)
    :param size: tuple for size. len(size)=dims.
    """
    dims = inds.shape[1]
    assert inds.shape[0] == vals.shape[0]
    assert len(size) == dims
    dev = vals.device
    result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype)  # flatten
    # fold multi-dimensional indices into linear (row-major) indices
    fac = [np.prod(size[i + 1 :]) for i in range(len(size) - 1)] + [1]
    fac = torch.tensor(fac, device=dev).type(inds.dtype)
    inds_fold = torch.sum(inds * fac, dim=-1)  # (#values,)
    result.scatter_add_(0, inds_fold, vals)
    result = result.view(*size)
    return result
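

# --- Editor's sketch, not part of the commit: scatter_to_grid sums duplicate
# hits, so the two values landing in cell (0, 1) accumulate.
def _demo_scatter_to_grid():
    import torch

    inds = torch.tensor([[0, 1], [0, 1], [1, 2]])  # (#values, dims)
    vals = torch.tensor([1.0, 2.0, 5.0])  # (#values,)
    out = scatter_to_grid(inds, vals, (2, 3))
    # out == [[0., 3., 0.],
    #         [0., 0., 5.]]
    assert out[0, 1] == 3.0 and out[1, 2] == 5.0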


def point_rasterize(pts, vals, size):
    """Rasterize point values onto a grid.

    :param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
    :param vals: point values, tensor of shape (batch, num_points, features)
    :param size: len(size)=dim tuple for grid size
    :return: rasterized values (batch, features, res0, res1, res2)
    """
    dim = pts.shape[-1]
    assert pts.shape[:2] == vals.shape[:2]
    assert pts.shape[2] == dim
    size_list = list(size)
    size = torch.tensor(size).to(pts.device).float()
    cubesize = 1.0 / size

@@ -152,55 +155,58 @@ def point_rasterize(pts, vals, size):

    nf = vals.shape[-1]
    npts = pts.shape[1]
    dev = pts.device
    ind0 = torch.floor(pts / cubesize).long()  # (batch, num_points, dim)
    ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long()  # periodic wrap-around
    ind01 = torch.stack((ind0, ind1), dim=0)  # (2, batch, num_points, dim)
    tmp = torch.tensor([0, 1], dtype=torch.long)
    com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
    dim_ = torch.arange(dim).repeat(com_.shape[0], 1)  # (2**dim, dim)
    ind_ = ind01[com_, ..., dim_]  # (2**dim, dim, batch, num_points)
    ind_n = ind_.permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    ind_b = (
        torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)
    )  # (batch, num_points, 2**dim)

    # weights of neighboring nodes
    xyz0 = ind0.type(cubesize.dtype) * cubesize  # (batch, num_points, dim)
    xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize  # (batch, num_points, dim)
    xyz01 = torch.stack((xyz0, xyz1), dim=0)  # (2, batch, num_points, dim)
    pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    pos_ = pos_.type(pts.dtype)
    dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize  # (batch, num_points, 2**dim, dim)
    weights = torch.prod(dxyz_, dim=-1, keepdim=False)  # (batch, num_points, 2**dim)

    ind_b = ind_b.unsqueeze(-1).unsqueeze(-1)  # (batch, num_points, 2**dim, 1, 1)
    ind_n = ind_n.unsqueeze(-2)  # (batch, num_points, 2**dim, 1, dim)
    ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1)  # (1, 1, 1, nf, 1)

    ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
    ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
    ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
    inds = torch.cat([ind_b, ind_f, ind_n], dim=-1)  # (batch, num_points, 2**dim, nf, 1+1+dim)

    # weighted values
    vals = weights.unsqueeze(-1) * vals.unsqueeze(-2)  # (batch, num_points, 2**dim, nf)

    inds = inds.view(-1, dim + 2).permute(1, 0).long()  # (1+dim+1, bs*npts*2**dim*nf)
    vals = vals.reshape(-1)  # (bs*npts*2**dim*nf)
    raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf, *size_list])

    return raster  # (batch, nf, res, res, res)
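

# --- Editor's sketch, not part of the commit: point_rasterize is the adjoint
# of grid_interp -- the same trilinear weights, used to scatter instead of
# gather. Assumes `bs` is the batch size taken from pts, as in the elided hunk.
def _demo_point_rasterize():
    import torch

    pts = torch.rand(2, 100, 3) * 0.999  # (batch, num_points, 3), in [0, 1)
    vals = torch.rand(2, 100, 8)  # (batch, num_points, features)
    raster = point_rasterize(pts, vals, (32, 32, 32))
    assert raster.shape == (2, 8, 32, 32, 32)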


##################################################
# Below are general utility functions


class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

@@ -226,49 +232,49 @@ class AverageMeter(object):

    def avgcavg(self):
        return self.avg.sum().item() / (self.count != 0).sum().item()


def load_model_manual(state_dict, model):
    new_state_dict = OrderedDict()
    is_model_parallel = isinstance(model, torch.nn.DataParallel)
    for k, v in state_dict.items():
        if k.startswith("module.") != is_model_parallel:
            if k.startswith("module."):
                # remove module
                k = k[7:]
            else:
                # add module
                k = "module." + k
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
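

# --- Editor's sketch, not part of the commit: checkpoints written by a
# DataParallel model carry a "module." prefix on every key; load_model_manual
# strips or adds the prefix so the checkpoint matches the target model.
def _demo_load_model_manual():
    import torch

    model = torch.nn.Linear(4, 2)
    ckpt = {"module." + k: v for k, v in model.state_dict().items()}
    load_model_manual(ckpt, model)  # prefix is removed before load_state_dict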


def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
    """Run marching cubes on a PSR grid."""
    batch_size = psr_grid.shape[0]
    s = psr_grid.shape[-1]  # size of psr_grid
    psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()

    if batch_size > 1:
        verts, faces, normals = [], [], []
        for i in range(batch_size):
            verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
            verts.append(verts_cur)
            faces.append(faces_cur)
            normals.append(normals_cur)
        verts = np.stack(verts, axis=0)
        faces = np.stack(faces, axis=0)
        normals = np.stack(normals, axis=0)
    else:
        try:
            verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
        except Exception:
            # fall back to skimage's default level if zero_level is out of range
            verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
    if real_scale:
        verts = verts / (s - 1)  # scale to range [0, 1]
    else:
        verts = verts / s  # scale to range [0, 1)

    if pytorchify:
        device = psr_grid.device

@@ -277,7 +283,8 @@ def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):

        normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)

    return verts, faces, normals
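

# --- Editor's sketch, not part of the commit: extract the zero level set of a
# signed grid (a sphere of radius 0.3 centred in the unit cube) with marching
# cubes. Requires skimage's `measure`, imported at module top.
def _demo_mc_from_psr():
    import torch

    res = 64
    x = torch.linspace(0, 1, res)
    coords = torch.stack(torch.meshgrid(x, x, x), dim=-1)  # (res, res, res, 3)
    sdf = (coords - 0.5).norm(dim=-1) - 0.3  # negative inside the sphere
    verts, faces, normals = mc_from_psr(sdf.unsqueeze(0), pytorchify=True)
    # verts lie in [0, 1); faces index into verts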


def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
    verts = verts.squeeze()
    faces = faces.squeeze()

@@ -288,42 +295,39 @@ def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):

    # find the 3D points intersected on the mesh
    if True:
        w_masked = w[mask]
        f_p = faces[pix_to_face[mask]].long()  # corresponding faces for each pixel
        # corresponding vertices for p_closest
        v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]

        # calculate the intersection point of each pixel and the mesh
        p_inters = w_masked[..., 0, None] * v_a + w_masked[..., 1, None] * v_b + w_masked[..., 2, None] * v_c
    else:
        # backproject ndc to world coordinates using the z-buffer
        W, H = img_size[1], img_size[0]
        xy = uv.to(mask.device)[mask]
        x_ndc = 1 - (2 * xy[:, 0]) / (W - 1)
        y_ndc = 1 - (2 * xy[:, 1]) / (H - 1)
        z = zbuf.squeeze().reshape(H * W)[mask]
        xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
        p_inters = pose.unproject_points(xy_depth, world_coordinates=True)

    # remove outlier points that fall outside the [-1, 1] cube
    if (p_inters.max() > 1) | (p_inters.min() < -1):
        mask_bound = (p_inters >= -1) & (p_inters <= 1)
        mask_bound = mask_bound.sum(dim=-1) == 3
        mask[mask] = mask_bound  # index with the boolean tensor itself to update only visible pixels
        p_inters = p_inters[mask_bound]
        print("!!!!! found outliers !!!!!")

    return p_inters, mask, f_p, w_masked


def mesh_rasterization(verts, faces, pose, img_size):
    """Use PyTorch3D to rasterize the mesh given a camera."""
    transformed_v = pose.transform_points(verts.detach())  # world -> ndc coordinate system
    if isinstance(pose, PerspectiveCameras):
        transformed_v[..., 2] = 1 / transformed_v[..., 2]
    # find p_closest on mesh of each pixel via rasterization
    transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
    pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(

@@ -331,9 +335,9 @@ def mesh_rasterization(verts, faces, pose, img_size):

        image_size=img_size,
        blur_radius=0,
        faces_per_pixel=1,
        perspective_correct=False,
    )
    pix_to_face = pix_to_face.reshape(1, -1)  # B x reso x reso -> B x (reso x reso)
    mask = pix_to_face.clone() != -1
    mask = mask.squeeze()
    pix_to_face = pix_to_face.squeeze()

@@ -341,11 +345,11 @@ def mesh_rasterization(verts, faces, pose, img_size):

    return pix_to_face, w, mask


def verts_on_largest_mesh(verts, faces):
    """Keep only the largest connected component of a mesh.

    verts: Numpy array or torch.Tensor (N, 3)
    faces: Numpy array (N, 3).
    """
    if torch.is_tensor(faces):
        verts = verts.squeeze().detach().cpu().numpy()
        faces = faces.squeeze().int().detach().cpu().numpy()

@@ -355,8 +359,8 @@ def verts_on_largest_mesh(verts, faces):

    if num == 0:
        v_large, f_large = verts, faces
    else:
        max_idx = conn_size.argmax()  # find the index of the largest component
        v_large = verts[conn_idx == max_idx]  # keep points on the largest component

        if True:
            mesh_largest = trimesh.Trimesh(verts, faces)

@@ -366,36 +370,41 @@ def verts_on_largest_mesh(verts, faces):

        v_large = v_large.astype(np.float32)
    return v_large, f_large


def load_pointcloud(in_file):
    plydata = PlyData.read(in_file)
    vertices = np.stack(
        [
            plydata["vertex"]["x"],
            plydata["vertex"]["y"],
            plydata["vertex"]["z"],
        ],
        axis=1,
    )
    return vertices


# General config
def load_config(path, default_path=None):
    """Load a config file.

    Args:
        path (str): path to config file
        default_path (str): optional path to a default config to inherit from
    """
    # Load configuration from file itself
    with open(path) as f:
        cfg_special = yaml.load(f, Loader=yaml.Loader)

    # Check if we should inherit from a config
    inherit_from = cfg_special.get("inherit_from")

    # If yes, load this config first as default
    # If no, use the default_path
    if inherit_from is not None:
        cfg = load_config(inherit_from, default_path)
    elif default_path is not None:
        with open(default_path) as f:
            cfg = yaml.load(f, Loader=yaml.Loader)
    else:
        cfg = dict()

@@ -405,65 +414,67 @@ def load_config(path, default_path=None):

    return cfg


def update_config(config, unknown):
    # update config from unparsed command-line args of the form --section:key value
    for idx, arg in enumerate(unknown):
        if arg.startswith("--"):
            keys = arg.replace("--", "").split(":")
            assert len(keys) == 2
            k1, k2 = keys
            argtype = type(config[k1][k2])
            if argtype == bool:
                v = unknown[idx + 1].lower() == "true"
            else:
                if config[k1][k2] is not None:
                    v = type(config[k1][k2])(unknown[idx + 1])
                else:
                    v = unknown[idx + 1]
            print(f"Changing {k1}:{k2} ---- {config[k1][k2]} to {v}")
            config[k1][k2] = v
    return config
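

# --- Editor's sketch, not part of the commit: leftover CLI args override the
# loaded config, with types inferred from the current values.
def _demo_update_config():
    cfg = {"train": {"lr": 0.001, "timestamp": False}}
    cfg = update_config(cfg, ["--train:lr", "0.0001", "--train:timestamp", "true"])
    assert cfg["train"]["lr"] == 0.0001 and cfg["train"]["timestamp"] is True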


def initialize_logger(cfg):
    out_dir = cfg["train"]["out_dir"]
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    cfg["train"]["dir_model"] = os.path.join(out_dir, "model")
    os.makedirs(cfg["train"]["dir_model"], exist_ok=True)

    if cfg["train"]["exp_mesh"]:
        cfg["train"]["dir_mesh"] = os.path.join(out_dir, "vis/mesh")
        os.makedirs(cfg["train"]["dir_mesh"], exist_ok=True)
    if cfg["train"]["exp_pcl"]:
        cfg["train"]["dir_pcl"] = os.path.join(out_dir, "vis/pointcloud")
        os.makedirs(cfg["train"]["dir_pcl"], exist_ok=True)
    if cfg["train"]["vis_rendering"]:
        cfg["train"]["dir_rendering"] = os.path.join(out_dir, "vis/rendering")
        os.makedirs(cfg["train"]["dir_rendering"], exist_ok=True)
    if cfg["train"]["o3d_show"]:
        cfg["train"]["dir_o3d"] = os.path.join(out_dir, "vis/o3d")
        os.makedirs(cfg["train"]["dir_o3d"], exist_ok=True)

    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    # ch = logging.StreamHandler()
    # logger.addHandler(ch)
    fh = logging.FileHandler(os.path.join(cfg["train"]["out_dir"], "log.txt"))
    logger.addHandler(fh)
    logger.info("Output dir: %s", out_dir)
    return logger


def update_recursive(dict1, dict2):
    """Update two config dictionaries recursively.

    Args:
        dict1 (dict): first dictionary, updated in place
        dict2 (dict): second dictionary, whose entries take precedence
    """
    for k, v in dict2.items():
        if k not in dict1:
            dict1[k] = dict()

@@ -472,6 +483,7 @@ def update_recursive(dict1, dict2):

        else:
            dict1[k] = v
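

# --- Editor's sketch, not part of the commit, assuming the elided branch
# recurses into nested dicts as in the upstream SAP code: nested defaults
# survive the merge unless dict2 overrides them explicitly.
def _demo_update_recursive():
    d1 = {"train": {"lr": 0.001, "batch_size": 32}}
    d2 = {"train": {"lr": 0.01}}
    update_recursive(d1, d2)
    assert d1 == {"train": {"lr": 0.01, "batch_size": 32}}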


def export_pointcloud(name, points, normals=None):
    if len(points.shape) > 2:
        points = points[0]

@@ -487,6 +499,7 @@ def export_pointcloud(name, points, normals=None):

        pcd.normals = o3d.utility.Vector3dVector(normals)
    o3d.io.write_point_cloud(name, pcd)


def export_mesh(name, v, f):
    if len(v.shape) > 2:
        v, f = v[0], f[0]

@@ -498,59 +511,63 @@ def export_mesh(name, v, f):

    mesh.triangles = o3d.utility.Vector3iVector(f)
    o3d.io.write_triangle_mesh(name, mesh)


def scale2onet(p, scale=1.2):
    """Scale the point cloud from SAP's [0, 1] range to the ONet range."""
    return (p - 0.5) * scale
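

# --- Editor's sketch, not part of the commit: with the default scale of 1.2,
# the SAP unit cube [0, 1] maps to [-0.6, 0.6].
def _demo_scale2onet():
    import torch

    p = torch.tensor([0.0, 0.5, 1.0])
    assert torch.allclose(scale2onet(p), torch.tensor([-0.6, 0.0, 0.6]))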


def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
    if model is not None:
        if schedule is not None:
            optimizer = torch.optim.Adam(
                [
                    {"params": model.parameters(), "lr": schedule[0].get_learning_rate(epoch)},
                    {"params": inputs, "lr": schedule[1].get_learning_rate(epoch)},
                ],
            )
        elif "lr" in cfg["train"]:
            optimizer = torch.optim.Adam(
                [
                    {"params": model.parameters(), "lr": float(cfg["train"]["lr"])},
                    {"params": inputs, "lr": float(cfg["train"]["lr_pcl"])},
                ],
            )
        else:
            msg = "no known learning rate"
            raise Exception(msg)
    else:
        if schedule is not None:
            optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
        else:
            optimizer = torch.optim.Adam([inputs], lr=float(cfg["train"]["lr_pcl"]))
    return optimizer
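

# --- Editor's sketch, not part of the commit: jointly optimise network
# weights and an input point cloud, each with its own step schedule. The spec
# keys ("model", "pcl") are hypothetical; only their insertion order matters.
def _demo_update_optimizer():
    import torch

    model = torch.nn.Linear(3, 3)
    pcl = torch.rand(1, 1024, 3, requires_grad=True)
    schedule = get_learning_rate_schedules({
        "model": {"initial": 1e-3, "interval": 100, "factor": 0.5, "final": 1e-6},
        "pcl": {"initial": 1e-2, "interval": 100, "factor": 0.5, "final": 1e-6},
    })
    opt = update_optimizer(pcl, cfg={}, epoch=0, model=model, schedule=schedule)
    assert opt.param_groups[0]["lr"] == 1e-3 and opt.param_groups[1]["lr"] == 1e-2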


def is_url(url):
    scheme = urllib.parse.urlparse(url).scheme
    return scheme in ("http", "https")


def load_url(url):
    """Load a model state dictionary from a URL.

    Args:
        url (str): url to saved model
    """
    print(url)
    print("=> Loading checkpoint from url...")
    state_dict = model_zoo.load_url(url, progress=True)
    return state_dict


class GaussianSmoothing(nn.Module):
    """Apply Gaussian smoothing on a 1d, 2d or 3d tensor.

    Filtering is performed separately for each channel in the input using a
    depthwise convolution.

    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.

@@ -558,8 +575,9 @@ class GaussianSmoothing(nn.Module):

        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """

    def __init__(self, channels, kernel_size, sigma, dim=3):
        super().__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):

@@ -569,15 +587,11 @@ class GaussianSmoothing(nn.Module):

        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [torch.arange(size, dtype=torch.float32) for size in kernel_size],
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

@@ -586,7 +600,7 @@ class GaussianSmoothing(nn.Module):

        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer("weight", kernel)
        self.groups = channels

        if dim == 1:

@@ -596,36 +610,44 @@ class GaussianSmoothing(nn.Module):

        elif dim == 3:
            self.conv = F.conv3d
        else:
            msg = f"Only 1, 2 and 3 dimensions are supported. Received {dim}."
            raise RuntimeError(msg)

    def forward(self, input):
        """Apply the Gaussian filter to the input.

        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)
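

# --- Editor's sketch, not part of the commit: depthwise Gaussian blur of a
# one-channel volume. The convolution is unpadded, so each spatial dim shrinks
# by kernel_size - 1.
def _demo_gaussian_smoothing():
    import torch

    smoother = GaussianSmoothing(channels=1, kernel_size=5, sigma=1.0, dim=3)
    vol = torch.rand(1, 1, 32, 32, 32)  # (batch, channels, D, H, W)
    out = smoother(vol)
    assert out.shape == (1, 1, 28, 28, 28)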


# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
def get_learning_rate_schedules(schedule_specs):
    schedules = []
    for key in schedule_specs.keys():
        schedules.append(
            StepLearningRateSchedule(
                schedule_specs[key]["initial"],
                schedule_specs[key]["interval"],
                schedule_specs[key]["factor"],
                schedule_specs[key]["final"],
            ),
        )
    return schedules


class LearningRateSchedule:
    def get_learning_rate(self, epoch):
        pass


class StepLearningRateSchedule(LearningRateSchedule):
    def __init__(self, initial, interval, factor, final=1e-6):
        self.initial = float(initial)

@@ -640,6 +662,7 @@ class StepLearningRateSchedule(LearningRateSchedule):

        else:
            return self.final
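

# --- Editor's sketch, not part of the commit, assuming the elided rule is the
# standard IGR step decay: lr = initial * factor ** (epoch // interval),
# floored at `final`.
def _demo_step_schedule():
    sched = StepLearningRateSchedule(initial=1e-3, interval=100, factor=0.5, final=1e-6)
    lrs = [sched.get_learning_rate(e) for e in (0, 100, 250)]
    # -> [0.001, 0.0005, 0.00025]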


def adjust_learning_rate(lr_schedules, optimizer, epoch):
    for i, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

@@ -1,95 +1,93 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import torch
from scipy import ndimage
from torchvision.utils import save_image
from tqdm import trange

from src.utils import calc_inters_points, grid_interp


def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None):
    """Visualize the current point cloud and extracted mesh in Open3D.

    Args:
        vis: Open3D visualizer (or None to skip rendering)
        points, normals: oriented point cloud tensors
        verts, faces: mesh extracted from the PSR grid
        cfg (dict): config dictionary (provides the output directory)
        it (int): current iteration, used to name the screenshots
        color_v: optional per-vertex colors
    """
    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(verts)
    mesh.triangles = o3d.utility.Vector3iVector(faces)
    mesh.paint_uniform_color(np.array([0.7, 0.7, 0.7]))
    if color_v is not None:
        mesh.vertex_colors = o3d.utility.Vector3dVector(color_v)

    if vis is not None:
        dir_o3d = cfg["train"]["dir_o3d"]
        p = points.squeeze(0).detach().cpu().numpy()
        n = normals.squeeze(0).detach().cpu().numpy()
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(p)
        pcd.normals = o3d.utility.Vector3dVector(n)
        pcd.paint_uniform_color(np.array([0.7, 0.7, 1.0]))
        # pcd = pcd.uniform_down_sample(5)

        vis.clear_geometries()
        vis.add_geometry(mesh)
        vis.update_geometry(mesh)

        #! Thingi wheel - an example for how to change cameras in Open3D viewers
        vis.get_view_control().set_front([0.0461, -0.7467, 0.6635])
        vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638])
        vis.get_view_control().set_up([0.0520, 0.6651, 0.7449])
        vis.get_view_control().set_zoom(0.7)
        vis.poll_events()
        out_path = os.path.join(dir_o3d, f"{it}.jpg")
        vis.capture_screen_image(out_path)

        vis.clear_geometries()
        vis.add_geometry(pcd, reset_bounding_box=False)
        vis.update_geometry(pcd)
        vis.get_render_option().point_show_normal = True  # visualize point normals
        vis.get_view_control().set_front([0.0461, -0.7467, 0.6635])
        vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638])
        vis.get_view_control().set_up([0.0520, 0.6651, 0.7449])
        vis.get_view_control().set_zoom(0.7)
        vis.poll_events()
        out_path = os.path.join(dir_o3d, f"{it}_pcd.jpg")
        vis.capture_screen_image(out_path)


def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name="video.mp4"):
    if pose is not None:
        device = psr_grid.device
        # get world coordinates of grid points in [-1, 1]
        res = psr_grid.shape[-1]
        x = torch.linspace(-1, 1, steps=res)
        co_x, co_y, co_z = torch.meshgrid(x, x, x)
        co_grid = torch.stack([co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)], dim=1).to(device).unsqueeze(0)

        # visualize the projected occ_soft value
        res = 128
        psr_grid = psr_grid.reshape(-1)
        out_mask = psr_grid > 0
        in_mask = psr_grid < 0
        pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze()
        vis_mask = (pix[..., 0] >= 0) & (pix[..., 0] <= res - 1) & (pix[..., 1] >= 0) & (pix[..., 1] <= res - 1)
        pix_out = pix[vis_mask & out_mask]
        pix_in = pix[vis_mask & in_mask]

        img = torch.ones([res, res]).to(device)
        psr_grid = torch.sigmoid(-psr_grid * 5)
        img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask]
        img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask]
        # save_image(img, 'tmp.png', normalize=True)
@@ -98,78 +96,78 @@ def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'):

    dir_psr_vis = out_dir
    os.makedirs(dir_psr_vis, exist_ok=True)
    psr_grid = psr_grid.squeeze().detach().cpu().numpy()
    s = psr_grid.shape[0]
    for idx in trange(s):
        my_dpi = 100
        plt.figure(figsize=(1000 / my_dpi, 300 / my_dpi), dpi=my_dpi)

        plt.subplot(1, 3, 1)
        plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode="nearest"), cmap="nipy_spectral")
        plt.clim(-1, 1)
        plt.colorbar()
        plt.title("x")
        plt.grid("off")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode="nearest"), cmap="nipy_spectral")
        plt.clim(-1, 1)
        plt.colorbar()
        plt.title("y")
        plt.grid("off")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.imshow(ndimage.rotate(psr_grid[:, :, idx], 90, mode="nearest"), cmap="nipy_spectral")
        plt.clim(-1, 1)
        plt.colorbar()
        plt.title("z")
        plt.grid("off")
        plt.axis("off")

        plt.savefig(os.path.join(dir_psr_vis, f"{idx}"), pad_inches=0, dpi=100)
        plt.close()

    os.system(f"rm {dir_psr_vis}/{out_video_name}")
    os.system(
        f"ffmpeg -framerate 25 -start_number 0 -i {dir_psr_vis}/%d.png -pix_fmt yuv420p -crf 17 {dir_psr_vis}/{out_video_name}",
    )
    return None


def visualize_mesh_phong(v, f, n, pose, img_size, name, device="cpu"):
    #! Mesh rendering using a Phong shading model
    _, mask, f_p, w = calc_inters_points(v, f, pose, img_size)
    n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
    n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze()
    n_inters = n_inters.detach().to(device)

    light_source = -pose.R @ pose.T.squeeze()
    light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float()
    diffuse_per = torch.Tensor([0.7, 0.7, 0.7]).float()
    ambient = torch.Tensor([0.3, 0.3, 0.3]).float()

    diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device)
    phong = torch.ones([img_size[0] * img_size[1], 3]).to(device)
    phong[mask] = (ambient.unsqueeze(0).to(device) + diffuse).clamp_max(1.0)
    pp = phong.reshape(img_size[0], img_size[1], -1)
    save_image(pp.permute(2, 0, 1), name)


def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None):
    p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
    # normals for p_inters
    n_inters = None
    if n is not None:
        n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
        n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze()
    if ray is not None:
        ray = ray.squeeze()[mask]

    fea = None
    if fea_grid is not None:
        fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze()

    # use an MLP to regress color
    color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze()

    return color_pred, mask

train.py (189 changed lines)

@@ -4,182 +4,185 @@ abspath = os.path.abspath(__file__)

dname = os.path.dirname(abspath)
os.chdir(dname)
import argparse
import shutil
import time

import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from src import config
from src.data import collate_remove_none, worker_init_fn
from src.model import Encode2Points
from src.training import Trainer
from src.utils import AverageMeter, initialize_logger, load_config, load_model_manual

np.set_printoptions(precision=4)


def main():
    parser = argparse.ArgumentParser(description="MNIST toy experiment")
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")

    args = parser.parse_args()
    cfg = load_config(args.config, "configs/default.yaml")
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    batch_size = cfg["train"]["batch_size"]
    model_selection_metric = cfg["train"]["model_selection_metric"]

    # PYTORCH VERSION > 1.0.0
    assert float(torch.__version__.split(".")[-3]) > 0

    # boilerplate
    if cfg["train"]["timestamp"]:
        cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
    logger = initialize_logger(cfg)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))

    logger.info("using GPU: " + torch.cuda.get_device_name(0))

    # TensorBoard writer
    tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
    if not os.path.exists(tblogdir):
        os.makedirs(tblogdir, exist_ok=True)
    writer = SummaryWriter(log_dir=tblogdir)

    inputs = None

    train_dataset = config.get_dataset("train", cfg)
    val_dataset = config.get_dataset("val", cfg)
    vis_dataset = config.get_dataset("vis", cfg)

    collate_fn = collate_remove_none
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=cfg["train"]["n_workers"],
        shuffle=True,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=cfg["train"]["n_workers_val"],
        shuffle=False,
        collate_fn=collate_remove_none,
        worker_init_fn=worker_init_fn,
    )
    vis_loader = torch.utils.data.DataLoader(
        vis_dataset,
        batch_size=1,
        num_workers=cfg["train"]["n_workers_val"],
        shuffle=False,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn,
    )

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
    else:
        model = Encode2Points(cfg).to(device)

    n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("Number of parameters: %d" % n_parameter)

    # load model
    try:
        state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
        load_model_manual(state_dict["state_dict"], model)

        out = "Load model from iteration %d" % state_dict.get("it", 0)
        logger.info(out)
    except Exception:
        # start from scratch when no checkpoint is found
        state_dict = dict()

    metric_val_best = state_dict.get("loss_val_best", np.inf)

    logger.info(f"Current best validation metric ({model_selection_metric}): {metric_val_best:.8f}")

    LR = float(cfg["train"]["lr"])
    optimizer = optim.Adam(model.parameters(), lr=LR)

    start_epoch = state_dict.get("epoch", -1)
    it = state_dict.get("it", -1)

    trainer = Trainer(cfg, optimizer, device=device)
    runtime = {}
    runtime["all"] = AverageMeter()

    # training loop
    for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
        for batch in train_loader:
            it += 1

            start = time.time()
            loss, loss_each = trainer.train_step(inputs, batch, model)
            # measure elapsed time
            end = time.time()
            runtime["all"].update(end - start)

            if it % cfg["train"]["print_every"] == 0:
                log_text = "[Epoch %02d] it=%d, loss=%.4f" % (epoch, it, loss)
                writer.add_scalar("train/loss", loss, it)
                if loss_each is not None:
                    for k, l in loss_each.items():
                        if l.item() != 0.0:
                            log_text += f" loss_{k}={l.item():.4f}"
                            writer.add_scalar("train/%s" % k, l, it)

                log_text += " time={:.3f} / {:.2f}".format(runtime["all"].val, runtime["all"].sum)
                logger.info(log_text)

            if (it > 0) & (it % cfg["train"]["visualize_every"] == 0):
                for i, batch_vis in enumerate(vis_loader):
                    trainer.save(model, batch_vis, it, i)
                    if i >= 4:
                        break
                logger.info("Saved mesh and pointcloud")

            # run validation
            if it > 0 and (it % cfg["train"]["validate_every"]) == 0:
                eval_dict = trainer.evaluate(val_loader, model)
                metric_val = eval_dict[model_selection_metric]
                logger.info(f"Validation metric ({model_selection_metric}): {metric_val:.4f}")

                for k, v in eval_dict.items():
                    writer.add_scalar("val/%s" % k, v, it)

                if metric_val <= metric_val_best:  # lower is better
                    metric_val_best = metric_val
                    logger.info("New best model (loss %.4f)" % metric_val_best)
                    state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
                    state["state_dict"] = model.state_dict()
                    torch.save(state, os.path.join(cfg["train"]["out_dir"], "model_best.pt"))

            # save checkpoint
            if (epoch > 0) & (it % cfg["train"]["checkpoint_every"] == 0):
                state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
                state["state_dict"] = model.state_dict()

                torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))

                if it % cfg["train"]["backup_every"] == 0:
                    torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % it + ".pt"))
                    logger.info("Backup model at iteration %d" % it)
                logger.info("Save new model at iteration %d" % it)


if __name__ == "__main__":
    main()