🎨 apply auto formatting

Laurent FAINSIN 2023-05-26 14:59:53 +02:00
parent ee1be08bce
commit 41b213e4a5
28 changed files with 2189 additions and 2048 deletions


@@ -1,155 +1,145 @@
import argparse
import os
import numpy as np
import pandas as pd
import torch
import trimesh
from torch.utils.data import Dataset, DataLoader
import numpy as np; np.set_printoptions(precision=4)
import shutil, argparse, time, os
import pandas as pd
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
from src.training import Trainer
from src.model import Encode2Points
from src.data import PointCloudField, IndexField, Shapes3dDataset
from src.utils import load_config, load_pointcloud
from src.eval import MeshEvaluator
from tqdm import tqdm
from pdb import set_trace as st
from src.data import IndexField, PointCloudField, Shapes3dDataset
from src.eval import MeshEvaluator
from src.utils import load_config, load_pointcloud
np.set_printoptions(precision=4)
def main():
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
parser = argparse.ArgumentParser(description="MNIST toy experiment")
parser.add_argument("config", type=str, help="Path to config file.")
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")
args = parser.parse_args()
cfg = load_config(args.config, 'configs/default.yaml')
cfg = load_config(args.config, "configs/default.yaml")
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
data_type = cfg['data']['data_type']
torch.device("cuda" if use_cuda else "cpu")
cfg["data"]["data_type"]
# Shorthands
out_dir = cfg['train']['out_dir']
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
out_dir = cfg["train"]["out_dir"]
generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])
if cfg['generation'].get('iter', 0)!=0:
generation_dir += '_%04d'%cfg['generation']['iter']
if cfg["generation"].get("iter", 0) != 0:
generation_dir += "_%04d" % cfg["generation"]["iter"]
elif args.iter is not None:
generation_dir += '_%04d'%args.iter
print('Evaluate meshes under %s'%generation_dir)
out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl')
out_file_class = os.path.join(generation_dir, 'eval_meshes.csv')
# PYTORCH VERSION > 1.0.0
assert(float(torch.__version__.split('.')[-3]) > 0)
generation_dir += "_%04d" % args.iter
pointcloud_field = PointCloudField(cfg['data']['pointcloud_file'])
print("Evaluate meshes under %s" % generation_dir)
out_file = os.path.join(generation_dir, "eval_meshes_full.pkl")
out_file_class = os.path.join(generation_dir, "eval_meshes.csv")
# PYTORCH VERSION > 1.0.0
assert float(torch.__version__.split(".")[-3]) > 0
pointcloud_field = PointCloudField(cfg["data"]["pointcloud_file"])
fields = {
'pointcloud': pointcloud_field,
'idx': IndexField(),
"pointcloud": pointcloud_field,
"idx": IndexField(),
}
print('Test split: ', cfg['data']['test_split'])
print("Test split: ", cfg["data"]["test_split"])
dataset_folder = cfg['data']['path']
dataset_folder = cfg["data"]["path"]
dataset = Shapes3dDataset(
dataset_folder, fields,
cfg['data']['test_split'],
categories=cfg['data']['class'], cfg=cfg)
dataset_folder, fields, cfg["data"]["test_split"], categories=cfg["data"]["class"], cfg=cfg,
)
# Loader
test_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, num_workers=0, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)
# Evaluator
evaluator = MeshEvaluator(n_points=100000)
eval_dicts = []
print('Evaluating meshes...')
for it, data in enumerate(tqdm(test_loader)):
eval_dicts = []
print("Evaluating meshes...")
for _it, data in enumerate(tqdm(test_loader)):
if data is None:
print('Invalid data.')
print("Invalid data.")
continue
mesh_dir = os.path.join(generation_dir, 'meshes')
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
mesh_dir = os.path.join(generation_dir, "meshes")
pointcloud_dir = os.path.join(generation_dir, "pointcloud")
# Get index etc.
idx = data['idx'].item()
idx = data["idx"].item()
try:
model_dict = dataset.get_model_dict(idx)
except AttributeError:
model_dict = {'model': str(idx), 'category': 'n/a'}
model_dict = {"model": str(idx), "category": "n/a"}
modelname = model_dict['model']
category_id = model_dict['category']
modelname = model_dict["model"]
category_id = model_dict["category"]
try:
category_name = dataset.metadata[category_id].get('name', 'n/a')
category_name = dataset.metadata[category_id].get("name", "n/a")
except AttributeError:
category_name = 'n/a'
category_name = "n/a"
if category_id != 'n/a':
if category_id != "n/a":
mesh_dir = os.path.join(mesh_dir, category_id)
pointcloud_dir = os.path.join(pointcloud_dir, category_id)
# Evaluate
pointcloud_tgt = data['pointcloud'].squeeze(0).numpy()
normals_tgt = data['pointcloud.normals'].squeeze(0).numpy()
pointcloud_tgt = data["pointcloud"].squeeze(0).numpy()
normals_tgt = data["pointcloud.normals"].squeeze(0).numpy()
eval_dict = {
'idx': idx,
'class id': category_id,
'class name': category_name,
'modelname':modelname,
"idx": idx,
"class id": category_id,
"class name": category_name,
"modelname": modelname,
}
eval_dicts.append(eval_dict)
# Evaluate mesh
if cfg['test']['eval_mesh']:
mesh_file = os.path.join(mesh_dir, '%s.off' % modelname)
if cfg["test"]["eval_mesh"]:
mesh_file = os.path.join(mesh_dir, "%s.off" % modelname)
if os.path.exists(mesh_file):
mesh = trimesh.load(mesh_file, process=False)
eval_dict_mesh = evaluator.eval_mesh(
mesh, pointcloud_tgt, normals_tgt)
eval_dict_mesh = evaluator.eval_mesh(mesh, pointcloud_tgt, normals_tgt)
for k, v in eval_dict_mesh.items():
eval_dict[k + ' (mesh)'] = v
eval_dict[k + " (mesh)"] = v
else:
print('Warning: mesh does not exist: %s' % mesh_file)
print("Warning: mesh does not exist: %s" % mesh_file)
# Evaluate point cloud
if cfg['test']['eval_pointcloud']:
pointcloud_file = os.path.join(
pointcloud_dir, '%s.ply' % modelname)
if cfg["test"]["eval_pointcloud"]:
pointcloud_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
if os.path.exists(pointcloud_file):
pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
eval_dict_pcl = evaluator.eval_pointcloud(
pointcloud, pointcloud_tgt)
eval_dict_pcl = evaluator.eval_pointcloud(pointcloud, pointcloud_tgt)
for k, v in eval_dict_pcl.items():
eval_dict[k + ' (pcl)'] = v
eval_dict[k + " (pcl)"] = v
else:
print('Warning: pointcloud does not exist: %s'
% pointcloud_file)
print("Warning: pointcloud does not exist: %s" % pointcloud_file)
# Create pandas dataframe and save
eval_df = pd.DataFrame(eval_dicts)
eval_df.set_index(['idx'], inplace=True)
eval_df.set_index(["idx"], inplace=True)
eval_df.to_pickle(out_file)
# Create CSV file with main statistics
eval_df_class = eval_df.groupby(by=['class name']).mean()
eval_df_class.loc['mean'] = eval_df_class.mean()
eval_df_class = eval_df.groupby(by=["class name"]).mean()
eval_df_class.loc["mean"] = eval_df_class.mean()
eval_df_class.to_csv(out_file_class)
# Print results
print(eval_df_class)
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()


@@ -1,155 +1,164 @@
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np; np.set_printoptions(precision=4)
import shutil, argparse, time, os
import pandas as pd
import argparse
import os
import shutil
from collections import defaultdict
from src import config
from src.utils import mc_from_psr, export_mesh, export_pointcloud
from src.dpsr import DPSR
from src.training import Trainer
from src.model import Encode2Points
from src.utils import load_config, load_model_manual, scale2onet, is_url, load_url
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from pdb import set_trace as st
from src import config
from src.dpsr import DPSR
from src.model import Encode2Points
from src.utils import (
export_mesh,
export_pointcloud,
is_url,
load_config,
load_model_manual,
load_url,
mc_from_psr,
scale2onet,
)
np.set_printoptions(precision=4)
def main():
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.')
parser = argparse.ArgumentParser(description="MNIST toy experiment")
parser.add_argument("config", type=str, help="Path to config file.")
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")
args = parser.parse_args()
cfg = load_config(args.config, 'configs/default.yaml')
cfg = load_config(args.config, "configs/default.yaml")
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
data_type = cfg['data']['data_type']
input_type = cfg['data']['input_type']
vis_n_outputs = cfg['generation']['vis_n_outputs']
cfg["data"]["data_type"]
cfg["data"]["input_type"]
vis_n_outputs = cfg["generation"]["vis_n_outputs"]
if vis_n_outputs is None:
vis_n_outputs = -1
# Shorthands
out_dir = cfg['train']['out_dir']
out_dir = cfg["train"]["out_dir"]
if not out_dir:
os.makedirs(out_dir)
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')
generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])
out_time_file = os.path.join(generation_dir, "time_generation_full.pkl")
out_time_file_class = os.path.join(generation_dir, "time_generation.pkl")
# PYTORCH VERSION > 1.0.0
assert(float(torch.__version__.split('.')[-3]) > 0)
assert float(torch.__version__.split(".")[-3]) > 0
dataset = config.get_dataset('test', cfg, return_idx=True)
test_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, num_workers=0, shuffle=False)
dataset = config.get_dataset("test", cfg, return_idx=True)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)
model = Encode2Points(cfg).to(device)
# load model
try:
if is_url(cfg['test']['model_file']):
state_dict = load_url(cfg['test']['model_file'])
elif cfg['generation'].get('iter', 0)!=0:
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% cfg['generation']['iter']))
generation_dir += '_%04d'%cfg['generation']['iter']
if is_url(cfg["test"]["model_file"]):
state_dict = load_url(cfg["test"]["model_file"])
elif cfg["generation"].get("iter", 0) != 0:
state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % cfg["generation"]["iter"]))
generation_dir += "_%04d" % cfg["generation"]["iter"]
elif args.iter is not None:
state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% args.iter))
state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % args.iter))
else:
state_dict = torch.load(os.path.join(out_dir, 'model_best.pt'))
state_dict = torch.load(os.path.join(out_dir, "model_best.pt"))
load_model_manual(state_dict['state_dict'], model)
load_model_manual(state_dict["state_dict"], model)
except:
print('Model loading error. Exiting.')
print("Model loading error. Exiting.")
exit()
# Generator
generator = config.get_generator(model, cfg, device=device)
# Determine what to generate
generate_mesh = cfg['generation']['generate_mesh']
generate_pointcloud = cfg['generation']['generate_pointcloud']
generate_mesh = cfg["generation"]["generate_mesh"]
generate_pointcloud = cfg["generation"]["generate_pointcloud"]
# Statistics
time_dicts = []
# Generate
model.eval()
dpsr = DPSR(res=(cfg['generation']['psr_resolution'],
cfg['generation']['psr_resolution'],
cfg['generation']['psr_resolution']),
sig= cfg['generation']['psr_sigma']).to(device)
dpsr = DPSR(
res=(
cfg["generation"]["psr_resolution"],
cfg["generation"]["psr_resolution"],
cfg["generation"]["psr_resolution"],
),
sig=cfg["generation"]["psr_sigma"],
).to(device)
# Count how many models already created
model_counter = defaultdict(int)
print('Generating...')
for it, data in enumerate(tqdm(test_loader)):
print("Generating...")
for _it, data in enumerate(tqdm(test_loader)):
# Output folders
mesh_dir = os.path.join(generation_dir, 'meshes')
in_dir = os.path.join(generation_dir, 'input')
pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
generation_vis_dir = os.path.join(generation_dir, 'vis', )
mesh_dir = os.path.join(generation_dir, "meshes")
in_dir = os.path.join(generation_dir, "input")
pointcloud_dir = os.path.join(generation_dir, "pointcloud")
generation_vis_dir = os.path.join(generation_dir, "vis")
# Get index etc.
idx = data['idx'].item()
idx = data["idx"].item()
try:
model_dict = dataset.get_model_dict(idx)
except AttributeError:
model_dict = {'model': str(idx), 'category': 'n/a'}
model_dict = {"model": str(idx), "category": "n/a"}
modelname = model_dict['model']
category_id = model_dict['category']
modelname = model_dict["model"]
category_id = model_dict["category"]
try:
category_name = dataset.metadata[category_id].get('name', 'n/a')
category_name = dataset.metadata[category_id].get("name", "n/a")
except AttributeError:
category_name = 'n/a'
if category_id != 'n/a':
category_name = "n/a"
if category_id != "n/a":
mesh_dir = os.path.join(mesh_dir, str(category_id))
pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
in_dir = os.path.join(in_dir, str(category_id))
folder_name = str(category_id)
if category_name != 'n/a':
folder_name = str(folder_name) + '_' + category_name.split(',')[0]
if category_name != "n/a":
folder_name = str(folder_name) + "_" + category_name.split(",")[0]
generation_vis_dir = os.path.join(generation_vis_dir, folder_name)
# Create directories if necessary
if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir):
os.makedirs(generation_vis_dir)
if generate_mesh and not os.path.exists(mesh_dir):
os.makedirs(mesh_dir)
if generate_pointcloud and not os.path.exists(pointcloud_dir):
os.makedirs(pointcloud_dir)
if not os.path.exists(in_dir):
os.makedirs(in_dir)
# Timing dict
time_dict = {
'idx': idx,
'class id': category_id,
'class name': category_name,
'modelname':modelname,
"idx": idx,
"class id": category_id,
"class name": category_name,
"modelname": modelname,
}
time_dicts.append(time_dict)
# Generate outputs
out_file_dict = {}
if generate_mesh:
#! deploy the generator to a separate class
out = generator.generate_mesh(data)
@@ -158,60 +167,56 @@ def main():
time_dict.update(stats_dict)
# Write output
mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname)
mesh_out_file = os.path.join(mesh_dir, "%s.off" % modelname)
export_mesh(mesh_out_file, scale2onet(v), f)
out_file_dict['mesh'] = mesh_out_file
out_file_dict["mesh"] = mesh_out_file
if generate_pointcloud:
pointcloud_out_file = os.path.join(
pointcloud_dir, '%s.ply' % modelname)
pointcloud_out_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
out_file_dict['pointcloud'] = pointcloud_out_file
if cfg['generation']['copy_input']:
inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
p = data.get('inputs').to(device)
out_file_dict["pointcloud"] = pointcloud_out_file
if cfg["generation"]["copy_input"]:
inputs_path = os.path.join(in_dir, "%s.ply" % modelname)
p = data.get("inputs").to(device)
export_pointcloud(inputs_path, scale2onet(p))
out_file_dict['in'] = inputs_path
out_file_dict["in"] = inputs_path
# Copy to visualization directory for first vis_n_output samples
c_it = model_counter[category_id]
if c_it < vis_n_outputs:
# Save output files
img_name = '%02d.off' % c_it
"%02d.off" % c_it
for k, filepath in out_file_dict.items():
ext = os.path.splitext(filepath)[1]
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
% (c_it, k, ext))
out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, k, ext))
shutil.copyfile(filepath, out_file)
# Also generate oracle meshes
if cfg['generation']['exp_oracle']:
points_gt = data.get('gt_points').to(device)
normals_gt = data.get('gt_points.normals').to(device)
psr_gt = dpsr(points_gt, normals_gt)
v, f, _ = mc_from_psr(psr_gt,
zero_level=cfg['data']['zero_level'])
out_file = os.path.join(generation_vis_dir, '%02d_%s%s'
% (c_it, 'mesh_oracle', '.off'))
export_mesh(out_file, scale2onet(v), f)
model_counter[category_id] += 1
# Also generate oracle meshes
if cfg["generation"]["exp_oracle"]:
points_gt = data.get("gt_points").to(device)
normals_gt = data.get("gt_points.normals").to(device)
psr_gt = dpsr(points_gt, normals_gt)
v, f, _ = mc_from_psr(psr_gt, zero_level=cfg["data"]["zero_level"])
out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, "mesh_oracle", ".off"))
export_mesh(out_file, scale2onet(v), f)
model_counter[category_id] += 1
# Create pandas dataframe and save
time_df = pd.DataFrame(time_dicts)
time_df.set_index(['idx'], inplace=True)
time_df.set_index(["idx"], inplace=True)
time_df.to_pickle(out_time_file)
# Create pickle files with main statistics
time_df_class = time_df.groupby(by=['class name']).mean()
time_df_class.loc['mean'] = time_df_class.mean()
time_df_class = time_df.groupby(by=["class name"]).mean()
time_df_class.loc["mean"] = time_df_class.mean()
time_df_class.to_pickle(out_time_file_class)
# Print results
print('Timings [s]:')
print("Timings [s]:")
print(time_df_class)
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()

optim.py

@@ -1,79 +1,86 @@
import argparse
import glob
import os
import shutil
import time
import numpy as np
import open3d as o3d
import torch
import trimesh
import shutil, argparse, time, os, glob
import numpy as np; np.set_printoptions(precision=4)
import open3d as o3d
from plyfile import PlyData
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes
from skimage import measure
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from torchvision.io import write_video
from src.optimization import Trainer
from src.utils import load_config, update_config, initialize_logger, \
get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\
update_optimizer, export_pointcloud
from skimage import measure
from plyfile import PlyData
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.structures import Meshes
from src.utils import (
AverageMeter,
adjust_learning_rate,
export_pointcloud,
get_learning_rate_schedules,
initialize_logger,
load_config,
update_config,
update_optimizer,
)
np.set_printoptions(precision=4)
def main():
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1457, metavar='S',
help='random seed')
args, unknown = parser.parse_known_args()
cfg = load_config(args.config, 'configs/default.yaml')
parser = argparse.ArgumentParser(description="MNIST toy experiment")
parser.add_argument("config", type=str, help="Path to config file.")
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1457, metavar="S", help="random seed")
args, unknown = parser.parse_known_args()
cfg = load_config(args.config, "configs/default.yaml")
cfg = update_config(cfg, unknown)
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
data_type = cfg['data']['data_type']
data_class = cfg['data']['class']
data_type = cfg["data"]["data_type"]
cfg["data"]["class"]
print(cfg['train']['out_dir'])
print(cfg["train"]["out_dir"])
# PYTORCH VERSION > 1.0.0
assert(float(torch.__version__.split('.')[-3]) > 0)
assert float(torch.__version__.split(".")[-3]) > 0
# boiler-plate
if cfg['train']['timestamp']:
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
if cfg["train"]["timestamp"]:
cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
logger = initialize_logger(cfg)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
shutil.copyfile(args.config,
os.path.join(cfg['train']['out_dir'], 'config.yaml'))
shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))
# tensorboardX writer
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
if not os.path.exists(tblogdir):
os.makedirs(tblogdir)
writer = SummaryWriter(log_dir=tblogdir)
SummaryWriter(log_dir=tblogdir)
# initialize o3d visualizer
vis = None
if cfg['train']['o3d_show']:
if cfg["train"]["o3d_show"]:
vis = o3d.visualization.Visualizer()
vis.create_window(width=cfg['train']['o3d_window_size'],
height=cfg['train']['o3d_window_size'])
vis.create_window(width=cfg["train"]["o3d_window_size"], height=cfg["train"]["o3d_window_size"])
# initialize dataset
if data_type == 'point':
if cfg['data']['object_id'] != -1:
data_paths = sorted(glob.glob(cfg['data']['data_path']))
data_path = data_paths[cfg['data']['object_id']]
print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths)))
if data_type == "point":
if cfg["data"]["object_id"] != -1:
data_paths = sorted(glob.glob(cfg["data"]["data_path"]))
data_path = data_paths[cfg["data"]["object_id"]]
print("Loaded %d/%d object" % (cfg["data"]["object_id"] + 1, len(data_paths)))
else:
data_path = cfg['data']['data_path']
print('Data loaded')
ext = data_path.split('.')[-1]
if ext == 'obj': # have GT mesh
data_path = cfg["data"]["data_path"]
print("Data loaded")
ext = data_path.split(".")[-1]
if ext == "obj": # have GT mesh
mesh = load_objs_as_meshes([data_path], device=device)
# scale the mesh into unit cube
verts = mesh.verts_packed()
@@ -81,20 +88,15 @@ def main():
center = verts.mean(0)
mesh.offset_verts_(-center.expand(N, 3))
scale = max((verts - center).abs().max(0)[0])
mesh.scale_verts_((1.0 / float(scale)))
mesh.scale_verts_(1.0 / float(scale))
# important for our DPSR to have the range in [0, 1), not reaching 1
mesh.scale_verts_(0.9)
target_pts, target_normals = sample_points_from_meshes(mesh,
num_samples=200000, return_normals=True)
elif ext == 'ply': # only have the point cloud
target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True)
elif ext == "ply": # only have the point cloud
plydata = PlyData.read(data_path)
vertices = np.stack([plydata['vertex']['x'],
plydata['vertex']['y'],
plydata['vertex']['z']], axis=1)
normals = np.stack([plydata['vertex']['nx'],
plydata['vertex']['ny'],
plydata['vertex']['nz']], axis=1)
vertices = np.stack([plydata["vertex"]["x"], plydata["vertex"]["y"], plydata["vertex"]["z"]], axis=1)
normals = np.stack([plydata["vertex"]["nx"], plydata["vertex"]["ny"], plydata["vertex"]["nz"]], axis=1)
N = vertices.shape[0]
center = vertices.mean(0)
scale = np.max(np.max(np.abs(vertices - center), axis=0))
@@ -104,212 +106,212 @@ def main():
target_pts = torch.tensor(vertices, device=device)[None].float()
target_normals = torch.tensor(normals, device=device)[None].float()
mesh = None # no GT mesh
mesh = None # no GT mesh
if not torch.is_tensor(center):
center = torch.from_numpy(center)
if not torch.is_tensor(scale):
scale = torch.from_numpy(np.array([scale]))
data = {'target_points': target_pts,
'target_normals': target_normals, # normals are never used
'gt_mesh': mesh}
data = {
"target_points": target_pts,
"target_normals": target_normals, # normals are never used
"gt_mesh": mesh,
}
else:
raise NotImplementedError
# save the input point cloud
if 'target_points' in data.keys():
outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply')
if 'target_normals' in data.keys():
export_pointcloud(outdir_pcl, data['target_points'], data['target_normals'])
if "target_points" in data.keys():
outdir_pcl = os.path.join(cfg["train"]["out_dir"], "target_pcl.ply")
if "target_normals" in data.keys():
export_pointcloud(outdir_pcl, data["target_points"], data["target_normals"])
else:
export_pointcloud(outdir_pcl, data['target_points'])
export_pointcloud(outdir_pcl, data["target_points"])
# save oracle PSR mesh (mesh from our PSR using GT point+normals)
if data.get('gt_mesh') is not None:
gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0)
pts_gt, norms_gt = sample_points_from_meshes(data['gt_mesh'],
num_samples=500000, return_normals=True)
if data.get("gt_mesh") is not None:
gt_verts, gt_faces = data["gt_mesh"].get_mesh_verts_faces(0)
pts_gt, norms_gt = sample_points_from_meshes(data["gt_mesh"], num_samples=500000, return_normals=True)
pts_gt = (pts_gt + 1) / 2
from src.dpsr import DPSR
dpsr_tmp = DPSR(res=(cfg['model']['grid_res'],
cfg['model']['grid_res'],
cfg['model']['grid_res']),
sig=cfg['model']['psr_sigma']).to(device)
dpsr_tmp = DPSR(
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
sig=cfg["model"]["psr_sigma"],
).to(device)
target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
target = torch.tanh(target)
s = target.shape[-1] # size of psr_grid
s = target.shape[-1] # size of psr_grid
psr_grid_numpy = target.squeeze().detach().cpu().numpy()
verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
verts = verts / s * 2. - 1 # [-1, 1]
verts = verts / s * 2.0 - 1 # [-1, 1]
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(verts)
mesh.triangles = o3d.utility.Vector3iVector(faces)
outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply')
outdir_mesh = os.path.join(cfg["train"]["out_dir"], "oracle_mesh.ply")
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
# initialize the source point cloud given an input mesh
if 'input_mesh' in cfg['train'].keys() and \
os.path.isfile(cfg['train']['input_mesh']):
if cfg['train']['input_mesh'].split('/')[-2] == 'mesh':
mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh'])
if "input_mesh" in cfg["train"].keys() and os.path.isfile(cfg["train"]["input_mesh"]):
if cfg["train"]["input_mesh"].split("/")[-2] == "mesh":
mesh_tmp = trimesh.load_mesh(cfg["train"]["input_mesh"])
verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
mesh = Meshes(verts=verts, faces=faces)
points, normals = sample_points_from_meshes(mesh,
num_samples=cfg['data']['num_points'], return_normals=True)
points, normals = sample_points_from_meshes(
mesh, num_samples=cfg["data"]["num_points"], return_normals=True,
)
# mesh is saved in the original scale of the gt
points -= center.float().to(device)
points /= scale.float().to(device)
points *= 0.9
# make sure the points are within the range of [0, 1)
points = points / 2. + 0.5
points = points / 2.0 + 0.5
else:
# directly initialize from a point cloud
pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh'])
pcd = o3d.io.read_point_cloud(cfg["train"]["input_mesh"])
points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
points -= center.float().to(device)
points /= scale.float().to(device)
points *= 0.9
points = points / 2. + 0.5
else: #! initialize our source point cloud from a sphere
sphere_radius = cfg['model']['sphere_radius']
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius,
count=[256,256])
points, idx = sphere_mesh.sample(cfg['data']['num_points'],
return_index=True)
points += 0.5 # make sure the points are within the range of [0, 1)
points = points / 2.0 + 0.5
else: #! initialize our source point cloud from a sphere
sphere_radius = cfg["model"]["sphere_radius"]
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256, 256])
points, idx = sphere_mesh.sample(cfg["data"]["num_points"], return_index=True)
points += 0.5 # make sure the points are within the range of [0, 1)
normals = sphere_mesh.face_normals[idx]
points = torch.from_numpy(points).unsqueeze(0).to(device)
normals = torch.from_numpy(normals).unsqueeze(0).to(device)
points = torch.log(points/(1-points)) # inverse sigmoid
points = torch.log(points / (1 - points)) # inverse sigmoid
inputs = torch.cat([points, normals], axis=-1).float()
inputs.requires_grad = True
model = None # no network
model = None # no network
# initialize optimizer
cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl']
print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial'])
if 'schedule' in cfg['train']:
lr_schedules = get_learning_rate_schedules(cfg['train']['schedule'])
cfg["train"]["schedule"]["pcl"]["initial"] = cfg["train"]["lr_pcl"]
print("Initial learning rate:", cfg["train"]["schedule"]["pcl"]["initial"])
if "schedule" in cfg["train"]:
lr_schedules = get_learning_rate_schedules(cfg["train"]["schedule"])
else:
lr_schedules = None
optimizer = update_optimizer(inputs, cfg,
epoch=0, model=model, schedule=lr_schedules)
optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules)
try:
# load model
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
if ('pcl' in state_dict.keys()) & (state_dict['pcl'] is not None):
inputs = state_dict['pcl'].to(device)
state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
if ("pcl" in state_dict.keys()) & (state_dict["pcl"] is not None):
inputs = state_dict["pcl"].to(device)
inputs.requires_grad = True
optimizer = update_optimizer(inputs, cfg,
epoch=state_dict.get('epoch'), schedule=lr_schedules)
out = "Load model from epoch %d" % state_dict.get('epoch', 0)
optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get("epoch"), schedule=lr_schedules)
out = "Load model from epoch %d" % state_dict.get("epoch", 0)
print(out)
logger.info(out)
except:
state_dict = dict()
start_epoch = state_dict.get('epoch', -1)
start_epoch = state_dict.get("epoch", -1)
trainer = Trainer(cfg, optimizer, device=device)
runtime = {}
runtime['all'] = AverageMeter()
runtime["all"] = AverageMeter()
# training loop
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
# schedule the learning rate
if (epoch>0) & (lr_schedules is not None):
if (epoch % lr_schedules[0].interval == 0):
if (epoch > 0) & (lr_schedules is not None):
if epoch % lr_schedules[0].interval == 0:
adjust_learning_rate(lr_schedules, optimizer, epoch)
if len(lr_schedules) >1:
print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch,
lr_schedules[0].get_learning_rate(epoch),
lr_schedules[1].get_learning_rate(epoch)))
if len(lr_schedules) > 1:
print(
"[epoch {}] net_lr: {}, pcl_lr: {}".format(
epoch, lr_schedules[0].get_learning_rate(epoch), lr_schedules[1].get_learning_rate(epoch),
),
)
else:
print('[epoch {}] adjust pcl_lr to: {}'.format(epoch,
lr_schedules[0].get_learning_rate(epoch)))
print(f"[epoch {epoch}] adjust pcl_lr to: {lr_schedules[0].get_learning_rate(epoch)}")
start = time.time()
loss, loss_each = trainer.train_step(data, inputs, model, epoch)
runtime['all'].update(time.time() - start)
runtime["all"].update(time.time() - start)
if epoch % cfg['train']['print_every'] == 0:
log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss)
if epoch % cfg["train"]["print_every"] == 0:
log_text = ("[Epoch %02d] loss=%.5f") % (epoch, loss)
if loss_each is not None:
for k, l in loss_each.items():
if l.item() != 0.:
log_text += (' loss_%s=%.5f') % (k, l.item())
log_text += (' time=%.3f / %.3f') % (runtime['all'].val,
runtime['all'].sum)
if l.item() != 0.0:
log_text += f" loss_{k}={l.item():.5f}"
log_text += (" time={:.3f} / {:.3f}").format(runtime["all"].val, runtime["all"].sum)
logger.info(log_text)
print(log_text)
# visualize point clouds and meshes
if (epoch % cfg['train']['visualize_every'] == 0) & (vis is not None):
if (epoch % cfg["train"]["visualize_every"] == 0) & (vis is not None):
trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)
# save outputs
if epoch % cfg['train']['save_every'] == 0:
trainer.save_mesh_pointclouds(inputs, epoch,
center.cpu().numpy(),
scale.cpu().numpy()*(1/0.9))
if epoch % cfg["train"]["save_every"] == 0:
trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy() * (1 / 0.9))
# save checkpoints
if (epoch > 0) & (epoch % cfg['train']['checkpoint_every'] == 0):
state = {'epoch': epoch}
pcl = None
if (epoch > 0) & (epoch % cfg["train"]["checkpoint_every"] == 0):
state = {"epoch": epoch}
if isinstance(inputs, torch.Tensor):
state['pcl'] = inputs.detach().cpu()
torch.save(state, os.path.join(cfg['train']['dir_model'],
'%04d' % epoch + '.pt'))
state["pcl"] = inputs.detach().cpu()
torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % epoch + ".pt"))
print("Save new model at epoch %d" % epoch)
logger.info("Save new model at epoch %d" % epoch)
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))
# resample and gradually add new points to the source pcl
if (epoch > 0) & \
(cfg['train']['resample_every']!=0) & \
(epoch % cfg['train']['resample_every'] == 0) & \
(epoch < cfg['train']['total_epochs']):
inputs = trainer.point_resampling(inputs)
optimizer = update_optimizer(inputs, cfg,
epoch=epoch, model=model, schedule=lr_schedules)
trainer = Trainer(cfg, optimizer, device=device)
if (
(epoch > 0)
& (cfg["train"]["resample_every"] != 0)
& (epoch % cfg["train"]["resample_every"] == 0)
& (epoch < cfg["train"]["total_epochs"])
):
inputs = trainer.point_resampling(inputs)
optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules)
trainer = Trainer(cfg, optimizer, device=device)
# visualize the Open3D outputs
if cfg['train']['o3d_show']:
out_video_dir = os.path.join(cfg['train']['out_dir'],
'vis/o3d/video.mp4')
if cfg["train"]["o3d_show"]:
out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video.mp4")
if os.path.isfile(out_video_dir):
os.system('rm {}'.format(out_video_dir))
os.system('ffmpeg -framerate 30 \
os.system(f"rm {out_video_dir}")
os.system(
"ffmpeg -framerate 30 \
-start_number 0 \
-i {}/vis/o3d/%04d.jpg \
-pix_fmt yuv420p \
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
out_video_dir = os.path.join(cfg['train']['out_dir'],
'vis/o3d/video_pcd.mp4')
-crf 17 {}".format(
cfg["train"]["out_dir"], out_video_dir,
),
)
out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video_pcd.mp4")
if os.path.isfile(out_video_dir):
os.system('rm {}'.format(out_video_dir))
os.system('ffmpeg -framerate 30 \
os.system(f"rm {out_video_dir}")
os.system(
"ffmpeg -framerate 30 \
-start_number 0 \
-i {}/vis/o3d/%04d_pcd.jpg \
-pix_fmt yuv420p \
-crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir))
print('Video saved.')
-crf 17 {}".format(
cfg["train"]["out_dir"], out_video_dir,
),
)
print("Video saved.")
if __name__ == '__main__':
if __name__ == "__main__":
main()


@@ -1,69 +1,80 @@
import sys, os
import argparse
import os
from src.utils import load_config
import subprocess
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ["MKL_THREADING_LAYER"] = "GNU"
def main():
parser = argparse.ArgumentParser(description="MNIST toy experiment")
parser.add_argument("config", type=str, help="Path to config file.")
parser.add_argument("--start_res", type=int, default=-1, help="Resolution to start with.")
parser.add_argument("--object_id", type=int, default=-1, help="Object index.")
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--start_res', type=int, default=-1, help='Resolution to start with.')
parser.add_argument('--object_id', type=int, default=-1, help='Object index.')
args, unknown = parser.parse_known_args()
cfg = load_config(args.config, "configs/default.yaml")
args, unknown = parser.parse_known_args()
cfg = load_config(args.config, 'configs/default.yaml')
resolutions=[32, 64, 128, 256]
iterations=[1000, 1000, 1000, 200]
lrs=[2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)] # reduce lr
for idx,(res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
if res<args.start_res:
resolutions = [32, 64, 128, 256]
iterations = [1000, 1000, 1000, 200]
lrs = [2e-3, 2e-3 * 0.7, 2e-3 * (0.7**2), 2e-3 * (0.7**3)] # reduce lr
for idx, (res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)):
if res < args.start_res:
continue
if res>cfg['model']['grid_res']:
if res > cfg["model"]["grid_res"]:
continue
psr_sigma= 2 if res<=128 else 3
psr_sigma = 2 if res <= 128 else 3
if res > 128:
psr_sigma = 5 if 'thingi_noisy' in args.config else 3
psr_sigma = 5 if "thingi_noisy" in args.config else 3
if args.object_id != -1:
out_dir = os.path.join(cfg['train']['out_dir'], 'object_%02d'%args.object_id, 'res_%d'%res)
out_dir = os.path.join(cfg["train"]["out_dir"], "object_%02d" % args.object_id, "res_%d" % res)
else:
out_dir = os.path.join(cfg['train']['out_dir'], 'res_%d'%res)
out_dir = os.path.join(cfg["train"]["out_dir"], "res_%d" % res)
# sample from mesh when resampling is enabled, otherwise reuse the pointcloud
init_shape='mesh' if cfg['train']['resample_every']>0 else 'pointcloud'
init_shape = "mesh" if cfg["train"]["resample_every"] > 0 else "pointcloud"
if args.object_id != -1:
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
'object_%02d'%args.object_id, 'res_%d' % (resolutions[idx-1]),
'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
input_mesh = (
"None"
if idx == 0
else os.path.join(
cfg["train"]["out_dir"],
"object_%02d" % args.object_id,
"res_%d" % (resolutions[idx - 1]),
"vis",
init_shape,
"%04d.ply" % (iterations[idx - 1]),
)
)
else:
input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'],
'res_%d' % (resolutions[idx-1]),
'vis', init_shape, '%04d.ply' % (iterations[idx-1]))
cmd = 'export MKL_SERVICE_FORCE_INTEL=1 && '
cmd += "python optim.py %s --model:grid_res %d --model:psr_sigma %d \
input_mesh = (
"None"
if idx == 0
else os.path.join(
cfg["train"]["out_dir"],
"res_%d" % (resolutions[idx - 1]),
"vis",
init_shape,
"%04d.ply" % (iterations[idx - 1]),
)
)
cmd = "export MKL_SERVICE_FORCE_INTEL=1 && "
cmd += (
"python optim.py %s --model:grid_res %d --model:psr_sigma %d \
--train:input_mesh %s --train:total_epochs %d \
--train:out_dir %s --train:lr_pcl %f \
--data:object_id %d" % (
args.config,
res,
psr_sigma,
input_mesh,
iteration,
out_dir,
lr,
args.object_id)
--data:object_id %d"
% (args.config, res, psr_sigma, input_mesh, iteration, out_dir, lr, args.object_id)
)
print(cmd)
os.system(cmd)
if __name__=="__main__":
if __name__ == "__main__":
main()


@@ -1,14 +1,16 @@
import os
import torch
import time
import multiprocessing
import os
import time
import numpy as np
import torch
from tqdm import tqdm
from src.dpsr import DPSR
data_path = 'data/ShapeNet' # path for ShapeNet from ONet
base = 'data' # output base directory
dataset_name = 'shapenet_psr'
data_path = "data/ShapeNet" # path for ShapeNet from ONet
base = "data" # output base directory
dataset_name = "shapenet_psr"
multiprocess = True
njobs = 8
save_pointcloud = True
@@ -20,52 +22,56 @@ padding = 1.2
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
def process_one(obj):
obj_name = obj.split('/')[-1]
c = obj.split('/')[-2]
def process_one(obj):
obj_name = obj.split("/")[-1]
c = obj.split("/")[-2]
# create new for the current object
out_path_cur = os.path.join(base, dataset_name, c)
out_path_cur_obj = os.path.join(out_path_cur, obj_name)
os.makedirs(out_path_cur_obj, exist_ok=True)
gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz')
gt_path = os.path.join(data_path, c, obj_name, "pointcloud.npz")
data = np.load(gt_path)
points = data['points']
normals = data['normals']
points = data["points"]
normals = data["normals"]
# normalize the point to [0, 1)
points = points / padding + 0.5
# to scale back during inference, we should:
#! p = (p - 0.5) * padding
if save_pointcloud:
outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz')
# np.savez(outdir, points=points, normals=normals)
np.savez(outdir, points=data['points'], normals=data['normals'])
# return
if save_psr_field:
psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None],
torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16)
outdir = os.path.join(out_path_cur_obj, 'psr.npz')
if save_pointcloud:
outdir = os.path.join(out_path_cur_obj, "pointcloud.npz")
# np.savez(outdir, points=points, normals=normals)
np.savez(outdir, points=data["points"], normals=data["normals"])
# return
if save_psr_field:
psr_gt = (
dpsr(torch.from_numpy(points.astype(np.float32))[None], torch.from_numpy(normals.astype(np.float32))[None])
.squeeze()
.cpu()
.numpy()
.astype(np.float16)
)
outdir = os.path.join(out_path_cur_obj, "psr.npz")
np.savez(outdir, psr=psr_gt)
def main(c):
print("---------------------------------------")
print(f"Processing {c} {split}")
print("---------------------------------------")
print('---------------------------------------')
print('Processing {} {}'.format(c, split))
print('---------------------------------------')
for split in ["train", "val", "test"]:
fname = os.path.join(data_path, c, split + ".lst")
with open(fname) as f:
obj_list = f.read().splitlines()
for split in ['train', 'val', 'test']:
fname = os.path.join(data_path, c, split+'.lst')
with open(fname, 'r') as f:
obj_list = f.read().splitlines()
obj_list = [c+'/'+s for s in obj_list]
obj_list = [c + "/" + s for s in obj_list]
if multiprocess:
# multiprocessing.set_start_method('spawn', force=True)
@@ -81,21 +87,30 @@ def main(c):
else:
for obj in tqdm(obj_list):
process_one(obj)
print('Done Processing {} {}!'.format(c, split))
print(f"Done Processing {c} {split}!")
if __name__ == "__main__":
classes = [
"02691156",
"02828884",
"02933112",
"02958343",
"03211117",
"03001627",
"03636649",
"03691459",
"04090263",
"04256520",
"04379243",
"04401088",
"04530566",
]
classes = ['02691156', '02828884', '02933112',
'02958343', '03211117', '03001627',
'03636649', '03691459', '04090263',
'04256520', '04379243', '04401088', '04530566']
t_start = time.time()
for c in classes:
main(c)
t_end = time.time()
print('Total processing time: ', t_end - t_start)
print("Total processing time: ", t_end - t_start)


@@ -1,146 +1,151 @@
import yaml
from torchvision import transforms
from src import data, generation
from src.dpsr import DPSR
from ipdb import set_trace as st
# Generator for final mesh extraction
def get_generator(model, cfg, device, **kwargs):
''' Returns the generator object.
"""Returns the generator object.
Args:
model (nn.Module): Occupancy Network model
cfg (dict): imported yaml config
device (device): pytorch device
'''
if cfg['generation']['psr_resolution'] == 0:
psr_res = cfg['model']['grid_res']
psr_sigma = cfg['model']['psr_sigma']
"""
if cfg["generation"]["psr_resolution"] == 0:
psr_res = cfg["model"]["grid_res"]
psr_sigma = cfg["model"]["psr_sigma"]
else:
psr_res = cfg['generation']['psr_resolution']
psr_sigma = cfg['generation']['psr_sigma']
dpsr = DPSR(res=(psr_res, psr_res, psr_res),
sig= psr_sigma).to(device)
psr_res = cfg["generation"]["psr_resolution"]
psr_sigma = cfg["generation"]["psr_sigma"]
dpsr = DPSR(res=(psr_res, psr_res, psr_res), sig=psr_sigma).to(device)
generator = generation.Generator3D(
model,
device=device,
threshold=cfg['data']['zero_level'],
sample=cfg['generation']['use_sampling'],
input_type = cfg['data']['input_type'],
padding=cfg['data']['padding'],
threshold=cfg["data"]["zero_level"],
sample=cfg["generation"]["use_sampling"],
input_type=cfg["data"]["input_type"],
padding=cfg["data"]["padding"],
dpsr=dpsr,
psr_tanh=cfg['model']['psr_tanh']
psr_tanh=cfg["model"]["psr_tanh"],
)
return generator
# Datasets
def get_dataset(mode, cfg, return_idx=False):
''' Returns the dataset.
"""Returns the dataset.
Args:
model (nn.Module): the model which is used
cfg (dict): config dictionary
return_idx (bool): whether to include an ID field
'''
dataset_type = cfg['data']['dataset']
dataset_folder = cfg['data']['path']
categories = cfg['data']['class']
"""
dataset_type = cfg["data"]["dataset"]
dataset_folder = cfg["data"]["path"]
categories = cfg["data"]["class"]
# Get split
splits = {
'train': cfg['data']['train_split'],
'val': cfg['data']['val_split'],
'test': cfg['data']['test_split'],
'vis': cfg['data']['val_split'],
"train": cfg["data"]["train_split"],
"val": cfg["data"]["val_split"],
"test": cfg["data"]["test_split"],
"vis": cfg["data"]["val_split"],
}
split = splits[mode]
# Create dataset
if dataset_type == 'Shapes3D':
if dataset_type == "Shapes3D":
fields = get_data_fields(mode, cfg)
# Input fields
inputs_field = get_inputs_field(mode, cfg)
if inputs_field is not None:
fields['inputs'] = inputs_field
fields["inputs"] = inputs_field
if return_idx:
fields['idx'] = data.IndexField()
fields["idx"] = data.IndexField()
dataset = data.Shapes3dDataset(
dataset_folder, fields,
dataset_folder,
fields,
split=split,
categories=categories,
cfg = cfg
cfg=cfg,
)
else:
raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])
raise ValueError('Invalid dataset "%s"' % cfg["data"]["dataset"])
return dataset
def get_inputs_field(mode, cfg):
''' Returns the inputs fields.
"""Returns the inputs fields.
Args:
mode (str): the mode which is used
cfg (dict): config dictionary
'''
input_type = cfg['data']['input_type']
"""
input_type = cfg["data"]["input_type"]
if input_type is None:
inputs_field = None
elif input_type == 'pointcloud':
noise_level = cfg['data']['pointcloud_noise']
if cfg['data']['pointcloud_outlier_ratio']>0:
transform = transforms.Compose([
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
data.PointcloudNoise(noise_level),
data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio'])
])
elif input_type == "pointcloud":
noise_level = cfg["data"]["pointcloud_noise"]
if cfg["data"]["pointcloud_outlier_ratio"] > 0:
transform = transforms.Compose(
[
data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]),
data.PointcloudNoise(noise_level),
data.PointcloudOutliers(cfg["data"]["pointcloud_outlier_ratio"]),
],
)
else:
transform = transforms.Compose([
data.SubsamplePointcloud(cfg['data']['pointcloud_n']),
data.PointcloudNoise(noise_level)
])
transform = transforms.Compose(
[
data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]),
data.PointcloudNoise(noise_level),
],
)
data_type = cfg['data']['data_type']
data_type = cfg["data"]["data_type"]
inputs_field = data.PointCloudField(
cfg['data']['pointcloud_file'], data_type, transform,
multi_files= cfg['data']['multi_files']
)
cfg["data"]["pointcloud_file"],
data_type,
transform,
multi_files=cfg["data"]["multi_files"],
)
else:
raise ValueError(
'Invalid input type (%s)' % input_type)
raise ValueError("Invalid input type (%s)" % input_type)
return inputs_field
def get_data_fields(mode, cfg):
''' Returns the data fields.
"""Returns the data fields.
Args:
mode (str): the mode which is used
cfg (dict): imported yaml config
'''
data_type = cfg['data']['data_type']
"""
data_type = cfg["data"]["data_type"]
fields = {}
if (mode in ('val', 'test')):
if mode in ("val", "test"):
transform = data.SubsamplePointcloud(100000)
else:
transform = data.SubsamplePointcloud(cfg['data']['num_gt_points'])
data_name = cfg['data']['pointcloud_file']
fields['gt_points'] = data.PointCloudField(data_name,
transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files'])
if data_type == 'psr_full':
if mode != 'test':
fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files'])
else:
raise ValueError('Invalid data type (%s)' % data_type)
transform = data.SubsamplePointcloud(cfg["data"]["num_gt_points"])
return fields
data_name = cfg["data"]["pointcloud_file"]
fields["gt_points"] = data.PointCloudField(
data_name, transform=transform, data_type=data_type, multi_files=cfg["data"]["multi_files"],
)
if data_type == "psr_full":
if mode != "test":
fields["gt_psr"] = data.FullPSRField(multi_files=cfg["data"]["multi_files"])
else:
raise ValueError("Invalid data type (%s)" % data_type)
return fields


@@ -1,14 +1,12 @@
from src.data.core import (
Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together
)
from src.data.fields import (
IndexField, PointCloudField, FullPSRField
)
from src.data.transforms import (
PointcloudNoise, SubsamplePointcloud,
PointcloudOutliers,
Shapes3dDataset,
collate_remove_none,
collate_stack_together,
worker_init_fn,
)
from src.data.fields import FullPSRField, IndexField, PointCloudField
from src.data.transforms import PointcloudNoise, PointcloudOutliers, SubsamplePointcloud
__all__ = [
# Core
Shapes3dDataset,


@@ -1,45 +1,41 @@
import os
import logging
from torch.utils import data
from pdb import set_trace as st
import os
import numpy as np
import yaml
from torch.utils import data
logger = logging.getLogger(__name__)
# Fields
class Field(object):
''' Data fields class.
'''
class Field:
"""Data fields class."""
def load(self, data_path, idx, category):
''' Loads a data point.
"""Loads a data point.
Args:
data_path (str): path to data file
idx (int): index of data point
category (int): index of category
'''
"""
raise NotImplementedError
def check_complete(self, files):
''' Checks if set is complete.
"""Checks if set is complete.
Args:
files: files
'''
"""
raise NotImplementedError
class Shapes3dDataset(data.Dataset):
''' 3D Shapes dataset class.
'''
"""3D Shapes dataset class."""
def __init__(self, dataset_folder, fields, split=None,
categories=None, no_except=True, transform=None, cfg=None):
''' Initialization of the the 3D shape dataset.
def __init__(self, dataset_folder, fields, split=None, categories=None, no_except=True, transform=None, cfg=None):
"""Initialization of the the 3D shape dataset.
Args:
dataset_folder (str): dataset folder
@@ -49,7 +45,7 @@ class Shapes3dDataset(data.Dataset):
no_except (bool): no exception
transform (callable): transformation applied to data points
cfg (yaml): config file
'''
"""
# Attributes
self.dataset_folder = dataset_folder
self.fields = fields
@@ -60,76 +56,69 @@ class Shapes3dDataset(data.Dataset):
# If categories is None, use all subfolders
if categories is None:
categories = os.listdir(dataset_folder)
categories = [c for c in categories
if os.path.isdir(os.path.join(dataset_folder, c))]
categories = [c for c in categories if os.path.isdir(os.path.join(dataset_folder, c))]
# Read metadata file
metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
metadata_file = os.path.join(dataset_folder, "metadata.yaml")
if os.path.exists(metadata_file):
with open(metadata_file, 'r') as f:
with open(metadata_file) as f:
self.metadata = yaml.load(f, Loader=yaml.Loader)
else:
self.metadata = {
c: {'id': c, 'name': 'n/a'} for c in categories
}
self.metadata = {c: {"id": c, "name": "n/a"} for c in categories}
# Set index
for c_idx, c in enumerate(categories):
self.metadata[c]['idx'] = c_idx
self.metadata[c]["idx"] = c_idx
# Get all models
self.models = []
for c_idx, c in enumerate(categories):
subpath = os.path.join(dataset_folder, c)
if not os.path.isdir(subpath):
logger.warning('Category %s does not exist in dataset.' % c)
logger.warning("Category %s does not exist in dataset." % c)
if split is None:
self.models += [
{'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ]
{"category": c, "model": m}
for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != "")]
]
else:
split_file = os.path.join(subpath, split + '.lst')
with open(split_file, 'r') as f:
models_c = f.read().split('\n')
if '' in models_c:
models_c.remove('')
split_file = os.path.join(subpath, split + ".lst")
with open(split_file) as f:
models_c = f.read().split("\n")
if "" in models_c:
models_c.remove("")
self.models += [{"category": c, "model": m} for m in models_c]
self.models += [
{'category': c, 'model': m}
for m in models_c
]
# precompute
self.split = split
def __len__(self):
''' Returns the length of the dataset.
'''
"""Returns the length of the dataset."""
return len(self.models)
def __getitem__(self, idx):
''' Returns an item of the dataset.
"""Returns an item of the dataset.
Args:
idx (int): ID of data point
'''
category = self.models[idx]['category']
model = self.models[idx]['model']
c_idx = self.metadata[category]['idx']
"""
category = self.models[idx]["category"]
model = self.models[idx]["model"]
c_idx = self.metadata[category]["idx"]
model_path = os.path.join(self.dataset_folder, category, model)
data = {}
info = c_idx
if self.cfg['data']['multi_files'] is not None:
idx = np.random.randint(self.cfg['data']['multi_files'])
if self.split != 'train':
if self.cfg["data"]["multi_files"] is not None:
idx = np.random.randint(self.cfg["data"]["multi_files"])
if self.split != "train":
idx = 0
for field_name, field in self.fields.items():
@@ -137,9 +126,8 @@ class Shapes3dDataset(data.Dataset):
field_data = field.load(model_path, idx, info)
except Exception:
if self.no_except:
logger.warn(
'Error occured when loading field %s of model %s'
% (field_name, model)
logger.warning(
f"Error occured when loading field {field_name} of model {model}",
)
return None
else:
@@ -150,7 +138,7 @@ class Shapes3dDataset(data.Dataset):
if k is None:
data[field_name] = v
else:
data['%s.%s' % (field_name, k)] = v
data[f"{field_name}.{k}"] = v
else:
data[field_name] = field_data
@@ -158,78 +146,76 @@ class Shapes3dDataset(data.Dataset):
data = self.transform(data)
return data
def get_model_dict(self, idx):
return self.models[idx]
def test_model_complete(self, category, model):
''' Tests if model is complete.
"""Tests if model is complete.
Args:
model (str): modelname
'''
"""
model_path = os.path.join(self.dataset_folder, category, model)
files = os.listdir(model_path)
for field_name, field in self.fields.items():
if not field.check_complete(files):
logger.warn('Field "%s" is incomplete: %s'
% (field_name, model_path))
logger.warning(f'Field "{field_name}" is incomplete: {model_path}')
return False
return True
def collate_remove_none(batch):
''' Collater that puts each data field into a tensor with outer dimension
"""Collater that puts each data field into a tensor with outer dimension
batch size.
Args:
batch: batch
'''
"""
batch = list(filter(lambda x: x is not None, batch))
return data.dataloader.default_collate(batch)
def collate_stack_together(batch):
''' Collater that puts each data field into a tensor with outer dimension
"""Collater that puts each data field into a tensor with outer dimension
batch size.
Args:
batch: batch
'''
"""
batch = list(filter(lambda x: x is not None, batch))
keys = batch[0].keys()
concat = {}
if len(batch)>1:
if len(batch) > 1:
for key in keys:
key_val = [item[key] for item in batch]
concat[key] = np.concatenate(key_val, axis=0)
if key == 'inputs':
if key == "inputs":
n_pts = [item[key].shape[0] for item in batch]
concat['batch_ind'] = np.concatenate(
[i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
concat["batch_ind"] = np.concatenate([i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0)
return data.dataloader.default_collate([concat])
else:
n_pts = batch[0]['inputs'].shape[0]
batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int)
n_pts = batch[0]["inputs"].shape[0]
batch[0]["batch_ind"] = np.zeros(n_pts, dtype=int)
return data.dataloader.default_collate(batch)
def worker_init_fn(worker_id):
''' Worker init function to ensure true randomness.
'''
"""Worker init function to ensure true randomness."""
def set_num_threads(nt):
try:
import mkl; mkl.set_num_threads(nt)
except:
try:
import mkl
mkl.set_num_threads(nt)
except:
pass
torch.set_num_threads(1)
os.environ['IPC_ENABLE']='1'
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
os.environ["IPC_ENABLE"] = "1"
for o in ["OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS", "MKL_NUM_THREADS"]:
os.environ[o] = str(nt)
random_data = os.urandom(4)

View file

@ -1,63 +1,61 @@
import os
import glob
import time
import random
from PIL import Image
import numpy as np
import trimesh
from src.data.core import Field
from pdb import set_trace as st
class IndexField(Field):
''' Basic index field.'''
"""Basic index field."""
def load(self, model_path, idx, category):
''' Loads the index field.
"""Loads the index field.
Args:
model_path (str): path to model
idx (int): ID of data point
category (int): index of category
'''
"""
return idx
def check_complete(self, files):
''' Check if field is complete.
"""Check if field is complete.
Args:
files: files
'''
"""
return True
class FullPSRField(Field):
def __init__(self, transform=None, multi_files=None):
self.transform = transform
# self.unpackbits = unpackbits
self.multi_files = multi_files
def load(self, model_path, idx, category):
def load(self, model_path, idx, category):
# try:
# t0 = time.time()
if self.multi_files is not None:
psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx))
psr_path = os.path.join(model_path, "psr", f"psr_{idx:02d}.npz")
else:
psr_path = os.path.join(model_path, 'psr.npz')
psr_path = os.path.join(model_path, "psr.npz")
psr_dict = np.load(psr_path)
# t1 = time.time()
psr = psr_dict['psr']
psr = psr_dict["psr"]
psr = psr.astype(np.float32)
# t2 = time.time()
# print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0))
data = {None: psr}
if self.transform is not None:
data = self.transform(data)
return data
class PointCloudField(Field):
''' Point cloud field.
"""Point cloud field.
It provides the field used for point cloud data. These are the points
randomly sampled on the mesh.
@ -66,53 +64,54 @@ class PointCloudField(Field):
file_name (str): file name
transform (list): list of transformations applied to data points
multi_files (callable): number of files
'''
"""
def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2):
self.file_name = file_name
self.data_type = data_type # to make sure the range of input is correct
self.data_type = data_type # to make sure the range of input is correct
self.transform = transform
self.multi_files = multi_files
self.padding = padding
self.scale = scale
def load(self, model_path, idx, category):
''' Loads the data point.
"""Loads the data point.
Args:
model_path (str): path to model
idx (int): ID of data point
category (int): index of category
'''
"""
if self.multi_files is None:
file_path = os.path.join(model_path, self.file_name)
else:
# num = np.random.randint(self.multi_files)
# file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num))
file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx))
file_path = os.path.join(model_path, self.file_name, "pointcloud_%02d.npz" % (idx))
pointcloud_dict = np.load(file_path)
points = pointcloud_dict['points'].astype(np.float32)
normals = pointcloud_dict['normals'].astype(np.float32)
points = pointcloud_dict["points"].astype(np.float32)
normals = pointcloud_dict["normals"].astype(np.float32)
data = {
None: points,
'normals': normals,
"normals": normals,
}
if self.transform is not None:
data = self.transform(data)
if self.data_type == 'psr_full':
if self.data_type == "psr_full":
# scale the point cloud to the range of (0, 1)
data[None] = data[None] / self.scale + 0.5
return data
def check_complete(self, files):
''' Check if field is complete.
"""Check if field is complete.
Args:
files: files
'''
complete = (self.file_name in files)
"""
complete = self.file_name in files
return complete

View file

@ -2,24 +2,24 @@ import numpy as np
# Transforms
class PointcloudNoise(object):
''' Point cloud noise transformation class.
class PointcloudNoise:
"""Point cloud noise transformation class.
It adds noise to point cloud data.
Args:
stddev (int): standard deviation
'''
"""
def __init__(self, stddev):
self.stddev = stddev
def __call__(self, data):
''' Calls the transformation.
"""Calls the transformation.
Args:
data (dictionary): data dictionary
'''
"""
data_out = data.copy()
points = data[None]
noise = self.stddev * np.random.randn(*points.shape)
@ -27,60 +27,63 @@ class PointcloudNoise(object):
data_out[None] = points + noise
return data_out
class PointcloudOutliers(object):
''' Point cloud outlier transformation class.
class PointcloudOutliers:
"""Point cloud outlier transformation class.
It adds outliers to point cloud data.
Args:
ratio (int): outlier percentage to the entire point cloud
'''
"""
def __init__(self, ratio):
self.ratio = ratio
def __call__(self, data):
''' Calls the transformation.
"""Calls the transformation.
Args:
data (dictionary): data dictionary
'''
"""
data_out = data.copy()
points = data[None]
n_points = points.shape[0]
n_outlier_points = int(n_points*self.ratio)
n_outlier_points = int(n_points * self.ratio)
ind = np.random.randint(0, n_points, n_outlier_points)
outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3))
outliers = outliers.astype(np.float32)
points[ind] = outliers
data_out[None] = points
return data_out
class SubsamplePointcloud(object):
''' Point cloud subsampling transformation class.
class SubsamplePointcloud:
"""Point cloud subsampling transformation class.
It subsamples the point cloud data.
Args:
N (int): number of points to be subsampled
'''
"""
def __init__(self, N):
self.N = N
def __call__(self, data):
''' Calls the transformation.
"""Calls the transformation.
Args:
data (dict): data dictionary
'''
"""
data_out = data.copy()
points = data[None]
indices = np.random.randint(points.shape[0], size=self.N)
data_out[None] = points[indices, :]
if 'normals' in data.keys():
normals = data['normals']
data_out['normals'] = normals[indices, :]
if "normals" in data.keys():
normals = data["normals"]
data_out["normals"] = normals[indices, :]
return data_out
return data_out
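A rough sketch of how the transforms above compose on the data dictionary used throughout this codebase; the `src.data` import path is assumed and the point cloud is random.

# Hedged sketch of the point-cloud transforms above; import path assumed.
import numpy as np

from src.data import PointcloudNoise, SubsamplePointcloud

data = {
    None: np.random.rand(10000, 3).astype(np.float32),        # points
    "normals": np.random.randn(10000, 3).astype(np.float32),  # per-point normals
}

data = SubsamplePointcloud(2048)(data)  # keep 2048 random points (and their normals)
data = PointcloudNoise(0.005)(data)     # add Gaussian noise with stddev 0.005

print(data[None].shape, data["normals"].shape)  # (2048, 3) (2048, 3)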

View file

@ -1,17 +1,20 @@
import os
import cv2
import torch
import numpy as np
from glob import glob
from torch.utils import data
from src.utils import load_rgb, load_mask, get_camera_params
import cv2
import numpy as np
import torch
from pytorch3d.renderer import PerspectiveCameras
from skimage import img_as_float32
from torch.utils import data
from src.utils import get_camera_params, load_mask, load_rgb
##################################################
# Below are for the differentiable renderer
# Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py
def load_rgb(path):
img = imageio.imread(path)
img = img_as_float32(img)
@ -23,6 +26,7 @@ def load_rgb(path):
# img = img.transpose(2, 0, 1)
return img
def load_mask(path):
alpha = imageio.imread(path, as_gray=True)
alpha = img_as_float32(alpha)
@ -32,13 +36,13 @@ def load_mask(path):
def get_camera_params(uv, pose, intrinsics):
if pose.shape[1] == 7: #In case of quaternion vector representation
if pose.shape[1] == 7: # In case of quaternion vector representation
cam_loc = pose[:, 4:]
R = quat_to_rot(pose[:,:4])
p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
R = quat_to_rot(pose[:, :4])
p = torch.eye(4).repeat(pose.shape[0], 1, 1).cuda().float()
p[:, :3, :3] = R
p[:, :3, 3] = cam_loc
else: # In case of pose matrix representation
else: # In case of pose matrix representation
cam_loc = pose[:, :3, 3]
p = pose
@ -60,25 +64,27 @@ def get_camera_params(uv, pose, intrinsics):
return ray_dirs, cam_loc
def quat_to_rot(q):
batch_size, _ = q.shape
q = F.normalize(q, dim=1)
R = torch.ones((batch_size, 3,3)).cuda()
qr=q[:,0]
R = torch.ones((batch_size, 3, 3)).cuda()
qr = q[:, 0]
qi = q[:, 1]
qj = q[:, 2]
qk = q[:, 3]
R[:, 0, 0]=1-2 * (qj**2 + qk**2)
R[:, 0, 1] = 2 * (qj *qi -qk*qr)
R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2)
R[:, 0, 1] = 2 * (qj * qi - qk * qr)
R[:, 0, 2] = 2 * (qi * qk + qr * qj)
R[:, 1, 0] = 2 * (qj * qi + qk * qr)
R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
R[:, 1, 2] = 2*(qj*qk - qi*qr)
R[:, 2, 0] = 2 * (qk * qi-qj * qr)
R[:, 2, 1] = 2 * (qj*qk + qi*qr)
R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2)
R[:, 1, 2] = 2 * (qj * qk - qi * qr)
R[:, 2, 0] = 2 * (qk * qi - qj * qr)
R[:, 2, 1] = 2 * (qj * qk + qi * qr)
R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2)
return R
def lift(x, y, z, intrinsics):
# parse intrinsics
# intrinsics = intrinsics.cuda()
@ -88,7 +94,16 @@ def lift(x, y, z, intrinsics):
cy = intrinsics[:, 1, 2]
sk = intrinsics[:, 0, 1]
x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
x_lift = (
(
x
- cx.unsqueeze(-1)
+ cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
- sk.unsqueeze(-1) * y / fy.unsqueeze(-1)
)
/ fx.unsqueeze(-1)
* z
)
y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
# homogeneous
@ -96,21 +111,18 @@ def lift(x, y, z, intrinsics):
class PixelNeRFDTUDataset(data.Dataset):
"""
Processed DTU from pixelNeRF
"""
def __init__(self,
data_dir='data/DTU',
scan_id=65,
img_size=None,
device=None,
fixed_scale=0,
):
data_dir = os.path.join(data_dir, "scan{}".format(scan_id))
rgb_paths = [
x for x in glob(os.path.join(data_dir, "image", "*"))
if (x.endswith(".jpg") or x.endswith(".png"))
]
"""Processed DTU from pixelNeRF."""
def __init__(
self,
data_dir="data/DTU",
scan_id=65,
img_size=None,
device=None,
fixed_scale=0,
):
data_dir = os.path.join(data_dir, f"scan{scan_id}")
rgb_paths = [x for x in glob(os.path.join(data_dir, "image", "*")) if (x.endswith((".jpg", ".png")))]
rgb_paths = sorted(rgb_paths)
mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))
if len(mask_paths) == 0:
@ -129,27 +141,24 @@ class PixelNeRFDTUDataset(data.Dataset):
all_T = []
for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)):
i = sel_indices[idx]
rgb = load_rgb(rgb_path)
mask = load_mask(mask_path)
rgb[~mask] = 0.
rgb[~mask] = 0.0
rgb = torch.from_numpy(rgb).float().to(device)
mask = torch.from_numpy(mask).float().to(device)
x_scale = y_scale = 1.0
xy_delta = 0.0
P = all_cam["world_mat_" + str(i)]
P = P[:3]
# scale the original shape to really [-0.9, 0.9]
if fixed_scale!=0.:
if fixed_scale != 0.0:
scale_mat_new = np.eye(4, 4)
scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9]
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new
else:
P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)]
P = P[:3, :4]
K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
K = K / K[2, 2]
@ -158,38 +167,34 @@ class PixelNeRFDTUDataset(data.Dataset):
########!!!!!
RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0)
tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0)
tt = torch.from_numpy(-R @ (t[:3] / t[3])).permute(1, 0)
focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
im_size = (rgb.shape[1], rgb.shape[0])
# check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC
s = min(im_size)
focal[:, 0] = focal[:, 0] * 2 / (s-1)
focal[:, 1] = focal[:, 1] * 2 /(s-1)
pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1)
pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1)
focal[:, 0] = focal[:, 0] * 2 / (s - 1)
focal[:, 1] = focal[:, 1] * 2 / (s - 1)
pc[:, 0] = -(pc[:, 0] - (im_size[0] - 1) / 2) * 2 / (s - 1)
pc[:, 1] = -(pc[:, 1] - (im_size[1] - 1) / 2) * 2 / (s - 1)
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc,
device=device, R=RR, T=tt)
camera = PerspectiveCameras(focal_length=-focal, principal_point=pc, device=device, R=RR, T=tt)
# calculate camera rays
uv = uv_creation(im_size)[None].float()
pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = R.transpose()
pose[:3,3] = (t[:3] / t[3])[:,0]
pose[:3, 3] = (t[:3] / t[3])[:, 0]
pose = torch.from_numpy(pose)[None].float()
intrinsics = np.eye(4)
intrinsics[:3, :3] = K
intrinsics[0, 1] = 0. #! remove skew for now
intrinsics[0, 1] = 0.0 #! remove skew for now
intrinsics = torch.from_numpy(intrinsics)[None].float()
rays, _ = get_camera_params(uv, pose, intrinsics)
rays = -rays.to(device)
all_poses.append(camera)
all_imgs.append(rgb)
all_masks.append(mask)
@ -198,7 +203,7 @@ class PixelNeRFDTUDataset(data.Dataset):
# only for neural renderer
all_K.append(torch.tensor(K).to(device))
all_R.append(torch.tensor(R).to(device))
all_T.append(torch.tensor(t[:3]/t[3]).to(device))
all_T.append(torch.tensor(t[:3] / t[3]).to(device))
all_imgs = torch.stack(all_imgs)
all_masks = torch.stack(all_masks)
@ -210,19 +215,20 @@ class PixelNeRFDTUDataset(data.Dataset):
all_T = torch.stack(all_T).permute(0, 2, 1).float()
uv = uv_creation((all_imgs.size(2), all_imgs.size(1)))
self.data = {'rgbs': all_imgs,
'masks': all_masks,
'poses': all_poses,
'rays': all_rays,
'uv': uv,
'light_pose': all_light_pose, # for rendering lights
'K': all_K,
'R': all_R,
'T': all_T,
}
self.data = {
"rgbs": all_imgs,
"masks": all_masks,
"poses": all_poses,
"rays": all_rays,
"uv": uv,
"light_pose": all_light_pose, # for rendering lights
"K": all_K,
"R": all_R,
"T": all_T,
}
def __len__(self):
return 1
def __getitem__(self, idx):
return self.data
return self.data

View file

@ -1,17 +1,17 @@
import torch
import torch.nn as nn
from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
import numpy as np
import torch
import torch.fft
import torch.nn as nn
from src.utils import fftfreqs, grid_interp, img, point_rasterize, spec_gaussian_filter
class DPSR(nn.Module):
def __init__(self, res, sig=10, scale=True, shift=True):
"""
:param res: tuple of output field resolution. eg., (128,128)
""":param res: tuple of output field resolution. eg., (128,128)
:param sig: degree of gaussian smoothing
"""
super(DPSR, self).__init__()
super().__init__()
self.res = res
self.sig = sig
self.dim = len(res)
@ -22,45 +22,44 @@ class DPSR(nn.Module):
self.scale = scale
self.shift = shift
self.register_buffer("G", G)
def forward(self, V, N):
"""
:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
""":param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
:param N: (batch, nv, 2 or 3) tensor for point normals
:return phi: (batch, res, res, ...) tensor of output indicator function field
"""
assert(V.shape == N.shape) # [b, nv, ndims]
assert V.shape == N.shape # [b, nv, ndims]
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
ras_s = torch.fft.rfftn(ras_p, dim=(2, 3, 4))
ras_s = ras_s.permute(*tuple([0, *list(range(2, self.dim + 1)), self.dim + 1, 1]))
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
omega *= 2 * np.pi # normalize frequencies
omega = omega.to(V.device)
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b]
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
Phi = DivN / (Lap + 1e-6) # [b, dim0, dim1, dim2/2+1, 2]
Phi = Phi.permute(*tuple([[*list(range(1, self.dim + 2)), 0]])) # [dim0, dim1, dim2/2+1, 2, b]
Phi[tuple([0] * self.dim)] = 0
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
Phi = Phi.permute(*tuple([[self.dim + 1, *list(range(self.dim + 1))]])) # [b, dim0, dim1, dim2/2+1, 2]
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1, 2, 3))
if self.shift or self.scale:
# ensure values at points are zero
fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
if self.shift: # offset points to have mean of 0
offset = torch.mean(fv, dim=-1) # [b,]
fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
if self.shift: # offset points to have mean of 0
offset = torch.mean(fv, dim=-1) # [b,]
phi -= offset.view(*tuple([-1] + [1] * self.dim))
phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
phi = phi.permute(*tuple([[*list(range(1, self.dim + 1)), 0]]))
fv0 = phi[tuple([0] * self.dim)] # [b,]
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
phi = phi.permute(*tuple([[self.dim, *list(range(self.dim))]]))
if self.scale:
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
return phi
phi = -phi / torch.abs(fv0.view(*tuple([-1] + [1] * self.dim))) * 0.5
return phi
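A minimal sketch of calling the differentiable Poisson solver above on a toy oriented point cloud; the `src.dpsr` module path is an assumption, and coordinates are expected to lie in the unit cube.

# Hedged sketch: DPSR on a random oriented point cloud in the unit cube.
import torch

from src.dpsr import DPSR  # module path assumed

dpsr = DPSR(res=(128, 128, 128), sig=10)

points = torch.rand(1, 2048, 3)   # (batch, nv, 3), values in (0, 1)
normals = torch.randn(1, 2048, 3)
normals = normals / normals.norm(dim=-1, keepdim=True)

psr_grid = dpsr(points, normals)  # indicator function field
print(psr_grid.shape)             # torch.Size([1, 128, 128, 128])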

View file

@ -1,43 +1,45 @@
import logging
import numpy as np
import trimesh
from pykdtree.kdtree import KDTree
EMPTY_PCL_DICT = {
'completeness': np.sqrt(3),
'accuracy': np.sqrt(3),
'completeness2': 3,
'accuracy2': 3,
'chamfer': 6,
"completeness": np.sqrt(3),
"accuracy": np.sqrt(3),
"completeness2": 3,
"accuracy2": 3,
"chamfer": 6,
}
EMPTY_PCL_DICT_NORMALS = {
'normals completeness': -1.,
'normals accuracy': -1.,
'normals': -1.,
"normals completeness": -1.0,
"normals accuracy": -1.0,
"normals": -1.0,
}
logger = logging.getLogger(__name__)
class MeshEvaluator(object):
''' Mesh evaluation class.
class MeshEvaluator:
"""Mesh evaluation class.
It handles the mesh evaluation process.
Args:
n_points (int): number of points to be used for evaluation
'''
n_points (int): number of points to be used for evaluation.
"""
def __init__(self, n_points=100000):
self.n_points = n_points
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)):
''' Evaluates a mesh.
def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1.0 / 1000, 1, 1000)):
"""Evaluates a mesh.
Args:
mesh (trimesh): mesh which should be evaluated
pointcloud_tgt (numpy array): target point cloud
normals_tgt (numpy array): target normals
thresholds (numpy array): for F-Score
'''
thresholds (numpy array): for F-Score.
"""
if len(mesh.vertices) != 0 and len(mesh.faces) != 0:
pointcloud, idx = mesh.sample(self.n_points, return_index=True)
@ -47,25 +49,25 @@ class MeshEvaluator(object):
pointcloud = np.empty((0, 3))
normals = np.empty((0, 3))
out_dict = self.eval_pointcloud(
pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
out_dict = self.eval_pointcloud(pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds)
return out_dict
def eval_pointcloud(self, pointcloud, pointcloud_tgt,
normals=None, normals_tgt=None,
thresholds=np.linspace(1./1000, 1, 1000)):
''' Evaluates a point cloud.
def eval_pointcloud(
self, pointcloud, pointcloud_tgt, normals=None, normals_tgt=None, thresholds=np.linspace(1.0 / 1000, 1, 1000),
):
"""Evaluates a point cloud.
Args:
pointcloud (numpy array): predicted point cloud
pointcloud_tgt (numpy array): target point cloud
normals (numpy array): predicted normals
normals_tgt (numpy array): target normals
thresholds (numpy array): threshold values for the F-score calculation
'''
thresholds (numpy array): threshold values for the F-score calculation.
"""
# Return maximum losses if pointcloud is empty
if pointcloud.shape[0] == 0:
logger.warn('Empty pointcloud / mesh detected!')
logger.warning("Empty pointcloud / mesh detected!")
out_dict = EMPTY_PCL_DICT.copy()
if normals is not None and normals_tgt is not None:
out_dict.update(EMPTY_PCL_DICT_NORMALS)
@ -74,11 +76,13 @@ class MeshEvaluator(object):
pointcloud = np.asarray(pointcloud)
pointcloud_tgt = np.asarray(pointcloud_tgt)
# Completeness: how far are the points of the target point cloud
# from the predicted point cloud
completeness, completeness_normals = distance_p2p(
pointcloud_tgt, normals_tgt, pointcloud, normals
pointcloud_tgt,
normals_tgt,
pointcloud,
normals,
)
recall = get_threshold_percentage(completeness, thresholds)
completeness2 = completeness**2
@ -90,7 +94,10 @@ class MeshEvaluator(object):
# Accuracy: how far are the points of the predicted pointcloud
# from the target pointcloud
accuracy, accuracy_normals = distance_p2p(
pointcloud, normals, pointcloud_tgt, normals_tgt
pointcloud,
normals,
pointcloud_tgt,
normals_tgt,
)
precision = get_threshold_percentage(accuracy, thresholds)
accuracy2 = accuracy**2
@ -101,68 +108,61 @@ class MeshEvaluator(object):
# Chamfer distance
chamferL2 = 0.5 * (completeness2 + accuracy2)
normals_correctness = (
0.5 * completeness_normals + 0.5 * accuracy_normals
)
normals_correctness = 0.5 * completeness_normals + 0.5 * accuracy_normals
chamferL1 = 0.5 * (completeness + accuracy)
# F-Score
F = [
2 * precision[i] * recall[i] / (precision[i] + recall[i])
for i in range(len(precision))
]
F = [2 * precision[i] * recall[i] / (precision[i] + recall[i]) for i in range(len(precision))]
out_dict = {
'completeness': completeness,
'accuracy': accuracy,
'normals completeness': completeness_normals,
'normals accuracy': accuracy_normals,
'normals': normals_correctness,
'completeness2': completeness2,
'accuracy2': accuracy2,
'chamfer-L2': chamferL2,
'chamfer-L1': chamferL1,
'f-score': F[9], # threshold = 1.0%
'f-score-15': F[14], # threshold = 1.5%
'f-score-20': F[19], # threshold = 2.0%
"completeness": completeness,
"accuracy": accuracy,
"normals completeness": completeness_normals,
"normals accuracy": accuracy_normals,
"normals": normals_correctness,
"completeness2": completeness2,
"accuracy2": accuracy2,
"chamfer-L2": chamferL2,
"chamfer-L1": chamferL1,
"f-score": F[9], # threshold = 1.0%
"f-score-15": F[14], # threshold = 1.5%
"f-score-20": F[19], # threshold = 2.0%
}
return out_dict
def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
''' Computes minimal distances of each point in points_src to points_tgt.
"""Computes minimal distances of each point in points_src to points_tgt.
Args:
points_src (numpy array): source points
normals_src (numpy array): source normals
points_tgt (numpy array): target points
normals_tgt (numpy array): target normals
'''
normals_tgt (numpy array): target normals.
"""
kdtree = KDTree(points_tgt)
dist, idx = kdtree.query(points_src)
if normals_src is not None and normals_tgt is not None:
normals_src = \
normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
normals_tgt = \
normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
normals_src = normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
normals_tgt = normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
# Handle normals that point into wrong direction gracefully
# (mostly due to method not caring about this in generation)
normals_dot_product = np.abs(normals_dot_product)
else:
normals_dot_product = np.array(
[np.nan] * points_src.shape[0], dtype=np.float32)
normals_dot_product = np.array([np.nan] * points_src.shape[0], dtype=np.float32)
return dist, normals_dot_product
def get_threshold_percentage(dist, thresholds):
''' Evaluates a point cloud.
"""Evaluates a point cloud.
Args:
dist (numpy array): calculated distance
thresholds (numpy array): threshold values for the F-score calculation
'''
in_threshold = [
(dist <= t).mean() for t in thresholds
]
return in_threshold
thresholds (numpy array): threshold values for the F-score calculation.
"""
in_threshold = [(dist <= t).mean() for t in thresholds]
return in_threshold
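A small, self-contained sketch of the evaluator in isolation, comparing a coarse sphere against points sampled from a finer one; the `src.eval` import path is assumed, and trimesh plus pykdtree are required.

# Hedged sketch: evaluating one mesh against points sampled from another.
import numpy as np
import trimesh

from src.eval import MeshEvaluator  # import path assumed

evaluator = MeshEvaluator(n_points=100000)

# "Ground truth": dense samples (with normals) from a finely subdivided sphere.
gt_mesh = trimesh.creation.icosphere(subdivisions=4, radius=0.45)
pointcloud_tgt, face_idx = gt_mesh.sample(100000, return_index=True)
pointcloud_tgt = pointcloud_tgt.astype(np.float32)
normals_tgt = gt_mesh.face_normals[face_idx].astype(np.float32)

# "Prediction": a coarser sphere of the same radius.
pred_mesh = trimesh.creation.icosphere(subdivisions=3, radius=0.45)

metrics = evaluator.eval_mesh(pred_mesh, pointcloud_tgt, normals_tgt)
print(metrics["chamfer-L1"], metrics["normals"], metrics["f-score"])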

View file

@ -1,11 +1,12 @@
import torch
import time
import trimesh
import numpy as np
import torch
from src.utils import mc_from_psr
class Generator3D(object):
''' Generator class for Occupancy Networks.
class Generator3D:
"""Generator class for Occupancy Networks.
It provides functions to generate the final mesh as well as refining options.
@ -17,11 +18,20 @@ class Generator3D(object):
padding (float): how much padding should be used for MISE
sample (bool): whether z should be sampled
input_type (str): type of input
'''
"""
def __init__(self, model, points_batch_size=100000,
threshold=0.5, device=None, padding=0.1,
sample=False, input_type = None, dpsr=None, psr_tanh=True):
def __init__(
self,
model,
points_batch_size=100000,
threshold=0.5,
device=None,
padding=0.1,
sample=False,
input_type=None,
dpsr=None,
psr_tanh=True,
):
self.model = model.to(device)
self.points_batch_size = points_batch_size
self.threshold = threshold
@ -31,33 +41,32 @@ class Generator3D(object):
self.sample = sample
self.dpsr = dpsr
self.psr_tanh = psr_tanh
def generate_mesh(self, data, return_stats=True):
''' Generates the output mesh.
"""Generates the output mesh.
Args:
data (tensor): data tensor
return_stats (bool): whether stats should be returned
'''
"""
self.model.eval()
device = self.device
stats_dict = {}
p = data.get('inputs', torch.empty(1, 0)).to(device)
p = data.get("inputs", torch.empty(1, 0)).to(device)
t0 = time.time()
points, normals = self.model(p)
t1 = time.time()
psr_grid = self.dpsr(points, normals)
t2 = time.time()
v, f, _ = mc_from_psr(psr_grid,
zero_level=self.threshold)
stats_dict['pcl'] = t1 - t0
stats_dict['dpsr'] = t2 - t1
stats_dict['mc'] = time.time() - t2
stats_dict['total'] = time.time() - t0
v, f, _ = mc_from_psr(psr_grid, zero_level=self.threshold)
stats_dict["pcl"] = t1 - t0
stats_dict["dpsr"] = t2 - t1
stats_dict["mc"] = time.time() - t2
stats_dict["total"] = time.time() - t0
if return_stats:
return v, f, points, normals, stats_dict
else:
return v, f, points, normals
return v, f, points, normals
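To illustrate the call pattern only, here is a hedged sketch using a stand-in model that projects the input onto a sphere; the `src.generation` and `src.dpsr` module paths are assumptions, and the threshold is set to 0 so marching cubes sits on the indicator's zero crossing.

# Hedged sketch of Generator3D.generate_mesh with a dummy point/normal model.
import torch
import torch.nn as nn

from src.dpsr import DPSR                # module paths assumed
from src.generation import Generator3D


class _SphereModel(nn.Module):
    """Stand-in for the learned model: snaps points to a sphere with outward normals."""

    def forward(self, p):
        centred = p - 0.5
        normals = centred / centred.norm(dim=-1, keepdim=True).clamp_min(1e-8)
        points = 0.5 + 0.25 * normals    # sphere of radius 0.25 inside the unit cube
        return points, normals


dpsr = DPSR(res=(128, 128, 128), sig=10)
generator = Generator3D(_SphereModel(), threshold=0.0, device="cpu", dpsr=dpsr)

data = {"inputs": torch.rand(1, 2048, 3)}  # unoriented points in (0, 1)
v, f, points, normals, stats = generator.generate_mesh(data)
print(v.shape, f.shape, stats["total"])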

View file

@ -1,18 +1,17 @@
import torch
import numpy as np
import time
from src.utils import point_rasterize, grid_interp, mc_from_psr, \
calc_inters_points
from src.dpsr import DPSR
import torch
import torch.nn as nn
from src.network import encoder_dict, decoder_dict
from src.network import decoder_dict, encoder_dict
from src.network.utils import map2local
from src.utils import calc_inters_points, grid_interp, mc_from_psr, point_rasterize
class PSR2Mesh(torch.autograd.Function):
@staticmethod
def forward(ctx, psr_grid):
"""
In the forward pass we receive a Tensor containing the input and return
"""In the forward pass we receive a Tensor containing the input and return
a Tensor containing the output. ctx is a context object that can be used
to stash information for backward computation. You can cache arbitrary
objects for use in the backward pass using the ctx.save_for_backward method.
@ -29,8 +28,7 @@ class PSR2Mesh(torch.autograd.Function):
@staticmethod
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
"""In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
"""
@ -39,17 +37,17 @@ class PSR2Mesh(torch.autograd.Function):
# matrix multiplication between dL/dV and dV/dPSR
# dV/dPSR = - normals
grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
return grad_grid
class PSR2SurfacePoints(torch.autograd.Function):
@staticmethod
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
verts = verts * 2. - 1. # within the range of [-1, 1]
verts = verts * 2.0 - 1.0 # within the range of [-1, 1]
p_all, n_all, mask_all = [], [], []
for i in range(len(poses)):
@ -67,7 +65,6 @@ class PSR2SurfacePoints(torch.autograd.Function):
n_inters_all = torch.cat(n_all, dim=0)
mask_visible = torch.stack(mask_all, dim=0)
res = torch.tensor(psr_grid.detach().shape[2])
ctx.save_for_backward(p_inters_all, n_inters_all, res)
@ -80,30 +77,31 @@ class PSR2SurfacePoints(torch.autograd.Function):
# grad from the p_inters via MLP renderer
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
grad_grid_pts = point_rasterize((pts[None] + 1) / 2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
return grad_grid_pts, None, None, None, None, None
class Encode2Points(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
encoder = cfg['model']['encoder']
decoder = cfg['model']['decoder']
dim = cfg['data']['dim'] # input dim
c_dim = cfg['model']['c_dim']
encoder_kwargs = cfg['model']['encoder_kwargs']
if encoder_kwargs == None:
encoder = cfg["model"]["encoder"]
decoder = cfg["model"]["decoder"]
dim = cfg["data"]["dim"] # input dim
c_dim = cfg["model"]["c_dim"]
encoder_kwargs = cfg["model"]["encoder_kwargs"]
if encoder_kwargs is None:
encoder_kwargs = {}
decoder_kwargs = cfg['model']['decoder_kwargs']
padding = cfg['data']['padding']
self.predict_normal = cfg['model']['predict_normal']
self.predict_offset = cfg['model']['predict_offset']
decoder_kwargs = cfg["model"]["decoder_kwargs"]
cfg["data"]["padding"]
self.predict_normal = cfg["model"]["predict_normal"]
self.predict_offset = cfg["model"]["predict_offset"]
out_dim = 3
out_dim_offset = 3
num_offset = cfg['data']['num_offset']
num_offset = cfg["data"]["num_offset"]
# each point predicts more than one offset to add output points
if num_offset > 1:
out_dim_offset = out_dim * num_offset
@ -111,44 +109,40 @@ class Encode2Points(nn.Module):
# local mapping
self.map2local = None
if cfg['model']['local_coord']:
if 'unet' in encoder_kwargs.keys():
unit_size = 1 / encoder_kwargs['plane_resolution']
if cfg["model"]["local_coord"]:
if "unet" in encoder_kwargs.keys():
unit_size = 1 / encoder_kwargs["plane_resolution"]
else:
unit_size = 1 / encoder_kwargs['grid_resolution']
unit_size = 1 / encoder_kwargs["grid_resolution"]
local_mapping = map2local(unit_size)
self.encoder = encoder_dict[encoder](
dim=dim, c_dim=c_dim, map2local=local_mapping,
**encoder_kwargs
dim=dim,
c_dim=c_dim,
map2local=local_mapping,
**encoder_kwargs,
)
if self.predict_normal:
# decoder for normal prediction
self.decoder_normal = decoder_dict[decoder](
dim=dim, c_dim=c_dim, out_dim=out_dim,
**decoder_kwargs)
self.decoder_normal = decoder_dict[decoder](dim=dim, c_dim=c_dim, out_dim=out_dim, **decoder_kwargs)
if self.predict_offset:
# decoder for offset prediction
self.decoder_offset = decoder_dict[decoder](
dim=dim, c_dim=c_dim, out_dim=out_dim_offset,
map2local=local_mapping,
**decoder_kwargs)
dim=dim, c_dim=c_dim, out_dim=out_dim_offset, map2local=local_mapping, **decoder_kwargs,
)
self.s_off = cfg["model"]["s_offset"]
self.s_off = cfg['model']['s_offset']
def forward(self, p):
''' Performs a forward pass through the network.
"""Performs a forward pass through the network.
Args:
p (tensor): input unoriented points
'''
"""
time_dict = {}
mask = None
batch_size = p.size(0)
points = p.clone()
@ -168,14 +162,12 @@ class Encode2Points(nn.Module):
if self.predict_normal:
normals = self.decoder_normal(points, c)
t2 = time.perf_counter()
time_dict['encode'] = t1 - t0
time_dict['predict'] = t2 - t1
points = torch.clamp(points, 0.0, 0.99)
if self.cfg['model']['normal_normalize']:
normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8)
time_dict["encode"] = t1 - t0
time_dict["predict"] = t2 - t1
points = torch.clamp(points, 0.0, 0.99)
if self.cfg["model"]["normal_normalize"]:
normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-8)
return points, normals

View file

@ -1,17 +1,19 @@
import torch
from pytorch3d.renderer import (
MeshRasterizer,
MeshRenderer,
PerspectiveCameras,
RasterizationSettings,
SoftSilhouetteShader,
)
from pytorch3d.structures import Meshes
from src.network.net_rgb import RenderingNetwork
from src.utils import approx_psr_grad
from pytorch3d.renderer import (
RasterizationSettings,
PerspectiveCameras,
MeshRenderer,
MeshRasterizer,
SoftSilhouetteShader)
from pytorch3d.structures import Meshes
def approx_psr_grad(psr_grid, res, normalize=True):
delta_x = delta_y = delta_z = 1/res
delta_x = delta_y = delta_z = 1 / res
psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze()
grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x
@ -35,76 +37,77 @@ class SAP2Image(nn.Module):
self.psr2sur = PSR2SurfacePoints.apply
self.psr2mesh = PSR2Mesh.apply
# initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
cfg['model']['grid_res'],
cfg['model']['grid_res']),
sig=cfg['model']['psr_sigma'])
self.dpsr = DPSR(
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
sig=cfg["model"]["psr_sigma"],
)
self.cfg = cfg
if cfg['train']['l_weight']['rgb'] != 0.:
self.rendering_network = RenderingNetwork(**cfg['model']['renderer'])
if cfg["train"]["l_weight"]["rgb"] != 0.0:
self.rendering_network = RenderingNetwork(**cfg["model"]["renderer"])
if cfg['train']['l_weight']['mask'] != 0.:
if cfg["train"]["l_weight"]["mask"] != 0.0:
# initialize rasterizer
sigma = 1e-4
raster_settings_soft = RasterizationSettings(
image_size=img_size,
blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
image_size=img_size,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
faces_per_pixel=150,
perspective_correct=False
perspective_correct=False,
)
# initialize silhouette renderer
# initialize silhouette renderer
self.mesh_rasterizer = MeshRenderer(
rasterizer=MeshRasterizer(
raster_settings=raster_settings_soft
raster_settings=raster_settings_soft,
),
shader=SoftSilhouetteShader()
shader=SoftSilhouetteShader(),
)
self.cfg = cfg
self.img_size = img_size
def forward(self, inputs, data):
points, normals = inputs[...,:3], inputs[...,3:]
points, normals = inputs[..., :3], inputs[..., 3:]
points = torch.sigmoid(points)
normals = normals / normals.norm(dim=-1, keepdim=True)
# DPSR to get grid
psr_grid = self.dpsr(points, normals).unsqueeze(1)
psr_grid = torch.tanh(psr_grid)
return self.render_img(psr_grid, data)
def render_img(self, psr_grid, data):
n_views = len(data['masks'])
n_views_per_iter = self.cfg['data']['n_views_per_iter']
rgb_render_mode = self.cfg['model']['renderer']['mode']
uv = data['uv']
n_views = len(data["masks"])
n_views_per_iter = self.cfg["data"]["n_views_per_iter"]
self.cfg["model"]["renderer"]["mode"]
uv = data["uv"]
idx = np.random.randint(0, n_views, n_views_per_iter)
pose = [data['poses'][i] for i in idx]
rgb = data['rgbs'][idx]
mask_gt = data['masks'][idx]
pose = [data["poses"][i] for i in idx]
rgb = data["rgbs"][idx]
mask_gt = data["masks"][idx]
ray = None
pred_rgb = None
pred_mask = None
if self.cfg['train']['l_weight']['rgb'] != 0.:
psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res'])
if self.cfg["train"]["l_weight"]["rgb"] != 0.0:
psr_grad = approx_psr_grad(psr_grid, self.cfg["model"]["grid_res"])
p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None)
n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2)
fea_interp = None
if 'rays' in data.keys():
ray = data['rays'].squeeze()[idx][visible_mask]
pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp)
if "rays" in data.keys():
ray = data["rays"].squeeze()[idx][visible_mask]
pred_rgb = self.rendering_network(
p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp,
)
# silhouette loss
if self.cfg['train']['l_weight']['mask'] != 0.:
if self.cfg["train"]["l_weight"]["mask"] != 0.0:
# build mesh
v, f, _ = self.psr2mesh(psr_grid)
v = v * 2. - 1 # within the range of [-1, 1]
v = v * 2.0 - 1 # within the range of [-1, 1]
# ! Fast but more GPU usage
mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()])
if True:
@ -114,11 +117,7 @@ class SAP2Image(nn.Module):
T = torch.cat([p.T for p in pose], dim=0)
focal = torch.cat([p.focal_length for p in pose], dim=0)
pp = torch.cat([p.principal_point for p in pose], dim=0)
pose_cur = PerspectiveCameras(
focal_length=focal,
principal_point=pp,
R=R, T=T,
device=R.device)
pose_cur = PerspectiveCameras(focal_length=focal, principal_point=pp, R=R, T=T, device=R.device)
pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3]
else:
pred_mask = []
@ -129,11 +128,11 @@ class SAP2Image(nn.Module):
pred_mask = torch.cat(pred_mask, dim=0)
output = {
'rgb': pred_rgb,
'rgb_gt': rgb,
'mask': pred_mask,
'mask_gt': mask_gt,
'vis_mask': visible_mask,
}
return output
"rgb": pred_rgb,
"rgb_gt": rgb,
"mask": pred_mask,
"mask_gt": mask_gt,
"vis_mask": visible_mask,
}
return output

View file

@ -1,8 +1,8 @@
from src.network import encoder, decoder
from src.network import decoder, encoder
encoder_dict = {
'local_pool_pointnet': encoder.LocalPoolPointnet,
"local_pool_pointnet": encoder.LocalPoolPointnet,
}
decoder_dict = {
'simple_local': decoder.LocalDecoder,
}
"simple_local": decoder.LocalDecoder,
}

View file

@ -1,15 +1,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ipdb import set_trace as st
from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \
normalize_coordinate
from src.network.utils import (
ResnetBlockFC,
normalize_3d_coordinate,
normalize_coordinate,
)
class LocalDecoder(nn.Module):
''' Decoder.
"""Decoder.
Instead of conditioning on global features, it conditions on plane/volume local features.
Args:
dim (int): input dimension
c_dim (int): dimension of latent conditioned code c
@ -17,26 +19,31 @@ class LocalDecoder(nn.Module):
n_blocks (int): number of blocks ResNetBlockFC layers
leaky (bool): whether to use leaky ReLUs
sample_mode (str): sampling feature strategy, bilinear|nearest
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
'''
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55].
"""
def __init__(self, dim=3, c_dim=128, out_dim=3,
hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None):
def __init__(
self,
dim=3,
c_dim=128,
out_dim=3,
hidden_size=256,
n_blocks=5,
leaky=False,
sample_mode="bilinear",
padding=0.1,
map2local=None,
):
super().__init__()
self.c_dim = c_dim
self.n_blocks = n_blocks
if c_dim != 0:
self.fc_c = nn.ModuleList([
nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
])
self.fc_c = nn.ModuleList([nn.Linear(c_dim, hidden_size) for i in range(n_blocks)])
self.fc_p = nn.Linear(dim, hidden_size)
self.blocks = nn.ModuleList([
ResnetBlockFC(hidden_size) for i in range(n_blocks)
])
self.blocks = nn.ModuleList([ResnetBlockFC(hidden_size) for i in range(n_blocks)])
self.fc_out = nn.Linear(hidden_size, out_dim)
@ -49,46 +56,45 @@ class LocalDecoder(nn.Module):
self.padding = padding
self.map2local = map2local
self.out_dim = out_dim
def sample_plane_feature(self, p, c, plane='xz'):
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
def sample_plane_feature(self, p, c, plane="xz"):
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
xy = xy[:, :, None].float()
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
c = F.grid_sample(c, vgrid, padding_mode='border',
align_corners=True,
mode=self.sample_mode).squeeze(-1)
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
c = F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode).squeeze(-1)
return c
def sample_grid_feature(self, p, c):
p_nor = normalize_3d_coordinate(p.clone())
p_nor = p_nor[:, :, None, None].float()
vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
# actually trilinear interpolation if mode = 'bilinear'
c = F.grid_sample(c, vgrid, padding_mode='border',
align_corners=True,
mode=self.sample_mode).squeeze(-1).squeeze(-1)
c = (
F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode)
.squeeze(-1)
.squeeze(-1)
)
return c
def forward(self, p, c_plane, **kwargs):
batch_size = p.shape[0]
plane_type = list(c_plane.keys())
c = 0
if 'grid' in plane_type:
c += self.sample_grid_feature(p, c_plane['grid'])
if 'xz' in plane_type:
c += self.sample_plane_feature(p, c_plane['xz'], plane='xz')
if 'xy' in plane_type:
c += self.sample_plane_feature(p, c_plane['xy'], plane='xy')
if 'yz' in plane_type:
c += self.sample_plane_feature(p, c_plane['yz'], plane='yz')
if "grid" in plane_type:
c += self.sample_grid_feature(p, c_plane["grid"])
if "xz" in plane_type:
c += self.sample_plane_feature(p, c_plane["xz"], plane="xz")
if "xy" in plane_type:
c += self.sample_plane_feature(p, c_plane["xy"], plane="xy")
if "yz" in plane_type:
c += self.sample_plane_feature(p, c_plane["yz"], plane="yz")
c = c.transpose(1, 2)
p = p.float()
if self.map2local:
p = self.map2local(p)
net = self.fc_p(p)
for i in range(self.n_blocks):
@ -98,9 +104,8 @@ class LocalDecoder(nn.Module):
net = self.blocks[i](net)
out = self.fc_out(self.actvn(net))
if self.out_dim > 3:
out = out.reshape(batch_size, -1, 3)
return out
return out
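A hedged, self-contained sketch of the decoder above queried with synthetic tri-plane features; the decoder is taken from decoder_dict as in model.py, and the feature shapes follow the 3-plane convention.

# Hedged sketch: LocalDecoder queried with random tri-plane features.
import torch

from src.network import decoder_dict  # as imported in model.py above

decoder = decoder_dict["simple_local"](
    dim=3, c_dim=32, out_dim=3, hidden_size=64, n_blocks=3, sample_mode="bilinear",
)

p = torch.rand(2, 1024, 3)  # query points in the unit cube
c_plane = {
    "xz": torch.rand(2, 32, 64, 64),
    "xy": torch.rand(2, 32, 64, 64),
    "yz": torch.rand(2, 32, 64, 64),
}

out = decoder(p, c_plane)   # e.g. per-point normals or offsets
print(out.shape)            # torch.Size([2, 1024, 3])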

View file

@ -1,17 +1,23 @@
import torch
import torch.nn as nn
import numpy as np
from src.network.unet3d import UNet3D
from torch_scatter import scatter_max, scatter_mean
from src.network.unet import UNet
from ipdb import set_trace as st
from torch_scatter import scatter_mean, scatter_max
from src.network.utils import get_embedder, normalize_3d_coordinate,\
coordinate2index, ResnetBlockFC, normalize_coordinate
from src.network.unet3d import UNet3D
from src.network.utils import (
ResnetBlockFC,
coordinate2index,
get_embedder,
normalize_3d_coordinate,
normalize_coordinate,
)
class LocalPoolPointnet(nn.Module):
''' PointNet-based encoder network with ResNet blocks for each point.
"""PointNet-based encoder network with ResNet blocks for each point.
Number of input points are fixed.
Args:
c_dim (int): dimension of latent code c
dim (int): input points dimension
@ -22,26 +28,38 @@ class LocalPoolPointnet(nn.Module):
unet3d (bool): whether to use 3D U-Net
unet3d_kwargs (str): 3D U-Net parameters
plane_resolution (int): defined resolution for plane feature
grid_resolution (int): defined resolution for grid feature
grid_resolution (int): defined resolution for grid feature
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
n_blocks (int): number of blocks ResNetBlockFC layers
map2local (function): map global coordinates to local ones
pos_encoding (int): frequency for the positional encoding
'''
"""
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
map2local=None, pos_encoding=0):
def __init__(
self,
c_dim=128,
dim=3,
hidden_dim=128,
scatter_type="max",
unet=False,
unet_kwargs=None,
unet3d=False,
unet3d_kwargs=None,
plane_resolution=None,
grid_resolution=None,
plane_type="xz",
padding=0.1,
n_blocks=5,
map2local=None,
pos_encoding=0,
):
super().__init__()
self.c_dim = c_dim
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
self.blocks = nn.ModuleList([
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
])
self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
self.blocks = nn.ModuleList([ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)])
self.fc_c = nn.Linear(hidden_dim, c_dim)
self.actvn = nn.ReLU()
@ -59,34 +77,36 @@ class LocalPoolPointnet(nn.Module):
self.reso_grid = grid_resolution
self.plane_type = plane_type
self.padding = padding
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
self.pe = None
if pos_encoding > 0:
embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim)
self.pe = embed_fn
self.fc_pos = nn.Linear(input_ch, 2*hidden_dim)
self.fc_pos = nn.Linear(input_ch, 2 * hidden_dim)
self.map2local = map2local
if scatter_type == 'max':
if scatter_type == "max":
self.scatter = scatter_max
elif scatter_type == 'mean':
elif scatter_type == "mean":
self.scatter = scatter_mean
else:
raise ValueError('incorrect scatter type')
msg = "incorrect scatter type"
raise ValueError(msg)
def generate_plane_features(self, p, c, plane='xz'):
def generate_plane_features(self, p, c, plane="xz"):
# acquire indices of features in plane
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1)
index = coordinate2index(xy, self.reso_plane)
# scatter plane features from points
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
c = c.permute(0, 2, 1) # B x 512 x T
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x 512 x reso x reso)
c = c.permute(0, 2, 1) # B x 512 x T
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
fea_plane = fea_plane.reshape(
p.size(0), self.c_dim, self.reso_plane, self.reso_plane,
) # sparse matrix (B x 512 x reso x reso)
# process the plane features with UNet
if self.unet is not None:
@ -96,12 +116,14 @@ class LocalPoolPointnet(nn.Module):
def generate_grid_features(self, p, c):
p_nor = normalize_3d_coordinate(p.clone())
index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
index = coordinate2index(p_nor, self.reso_grid, coord_type="3d")
# scatter grid features from points
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
c = c.permute(0, 2, 1)
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparse matrix (B x 512 x reso x reso x reso)
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
fea_grid = fea_grid.reshape(
p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid,
) # sparse matrix (B x 512 x reso x reso x reso)
if self.unet3d is not None:
fea_grid = self.unet3d(fea_grid)
@ -115,7 +137,7 @@ class LocalPoolPointnet(nn.Module):
c_out = 0
for key in keys:
# scatter plane features from points
if key == 'grid':
if key == "grid":
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
else:
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
@ -126,33 +148,31 @@ class LocalPoolPointnet(nn.Module):
c_out += fea
return c_out.permute(0, 2, 1)
def forward(self, p, normalize=True):
batch_size, T, D = p.size()
# acquire the index for each point
coord = {}
index = {}
if 'xz' in self.plane_type:
coord['xz'] = normalize_coordinate(p.clone(), plane='xz')
index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
if 'xy' in self.plane_type:
coord['xy'] = normalize_coordinate(p.clone(), plane='xy')
index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
if 'yz' in self.plane_type:
coord['yz'] = normalize_coordinate(p.clone(), plane='yz')
index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
if 'grid' in self.plane_type:
if "xz" in self.plane_type:
coord["xz"] = normalize_coordinate(p.clone(), plane="xz")
index["xz"] = coordinate2index(coord["xz"], self.reso_plane)
if "xy" in self.plane_type:
coord["xy"] = normalize_coordinate(p.clone(), plane="xy")
index["xy"] = coordinate2index(coord["xy"], self.reso_plane)
if "yz" in self.plane_type:
coord["yz"] = normalize_coordinate(p.clone(), plane="yz")
index["yz"] = coordinate2index(coord["yz"], self.reso_plane)
if "grid" in self.plane_type:
if normalize:
coord['grid'] = normalize_3d_coordinate(p.clone())
coord["grid"] = normalize_3d_coordinate(p.clone())
else:
coord['grid'] = p.clone()[...,:3]
index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
coord["grid"] = p.clone()[..., :3]
index["grid"] = coordinate2index(coord["grid"], self.reso_grid, coord_type="3d")
if self.pe:
p = self.pe(p)
if self.map2local:
pp = self.map2local(p)
net = self.fc_pos(pp)
@ -169,13 +189,13 @@ class LocalPoolPointnet(nn.Module):
c = self.fc_c(net)
fea = {}
if 'grid' in self.plane_type:
fea['grid'] = self.generate_grid_features(p, c)
if 'xz' in self.plane_type:
fea['xz'] = self.generate_plane_features(p, c, plane='xz')
if 'xy' in self.plane_type:
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
if 'yz' in self.plane_type:
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
if "grid" in self.plane_type:
fea["grid"] = self.generate_grid_features(p, c)
if "xz" in self.plane_type:
fea["xz"] = self.generate_plane_features(p, c, plane="xz")
if "xy" in self.plane_type:
fea["xy"] = self.generate_plane_features(p, c, plane="xy")
if "yz" in self.plane_type:
fea["yz"] = self.generate_plane_features(p, c, plane="yz")
return fea
return fea
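And a matching sketch for the encoder above, producing the tri-plane features the decoder consumes; the encoder is taken from encoder_dict as in model.py, torch_scatter must be installed, and the resolutions are illustrative.

# Hedged sketch: LocalPoolPointnet turning a point cloud into tri-plane features.
import torch

from src.network import encoder_dict  # as imported in model.py above

encoder = encoder_dict["local_pool_pointnet"](
    c_dim=32,
    dim=3,
    hidden_dim=32,
    scatter_type="max",
    plane_resolution=64,
    plane_type=["xz", "xy", "yz"],
)

p = torch.rand(2, 1024, 3)  # unit-cube point clouds
fea = encoder(p)
print({k: tuple(v.shape) for k, v in fea.items()})
# {'xz': (2, 32, 64, 64), 'xy': (2, 32, 64, 64), 'yz': (2, 32, 64, 64)}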

View file

@ -1,39 +1,41 @@
# code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py)
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from src.network.utils import get_embedder
from pdb import set_trace as st
class RenderingNetwork(nn.Module):
def __init__(
self,
fea_size=0,
mode='naive',
d_out=3,
dims=[512, 512, 512, 512],
weight_norm=True,
pe_freq_view=0 # for positional encoding
self,
fea_size=0,
mode="naive",
d_out=3,
dims=[512, 512, 512, 512],
weight_norm=True,
pe_freq_view=0, # for positional encoding
):
super().__init__()
self.mode = mode
if mode == 'naive':
if mode == "naive":
d_in = 3
elif mode == 'no_feature':
elif mode == "no_feature":
d_in = 3 + 3 + 3
fea_size = 0
elif mode == 'full':
elif mode == "full":
d_in = 3 + 3 + 3
else:
d_in = 3 + 3
dims = [d_in + fea_size] + dims + [d_out]
dims = [d_in + fea_size, *dims, d_out]
self.embedview_fn = None
if pe_freq_view > 0:
embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3)
self.embedview_fn = embedview_fn
dims[0] += (input_ch - 3)
dims[0] += input_ch - 3
self.num_layers = len(dims)
@ -54,13 +56,13 @@ class RenderingNetwork(nn.Module):
view_dirs = self.embedview_fn(view_dirs)
# points = self.embedview_fn(points)
if (self.mode == 'full') & (feature_vectors is not None):
if (self.mode == "full") & (feature_vectors is not None):
rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
elif (self.mode == 'no_feature') | ((self.mode == 'full') & (feature_vectors is None)):
elif (self.mode == "no_feature") | ((self.mode == "full") & (feature_vectors is None)):
rendering_input = torch.cat([points, view_dirs, normals], dim=-1)
elif self.mode == 'no_view_dir':
elif self.mode == "no_view_dir":
rendering_input = torch.cat([points, normals], dim=-1)
elif self.mode == 'no_normal':
elif self.mode == "no_normal":
rendering_input = torch.cat([points, view_dirs], dim=-1)
else:
rendering_input = points
@ -81,28 +83,27 @@ class RenderingNetwork(nn.Module):
class NeRFRenderingNetwork(nn.Module):
def __init__(
self,
feature_vector_size=0,
mode='naive',
d_in=3,
d_out=3,
dims=[512, 512, 512, 256],
weight_norm=True,
multires=0, # positional encoding of points
multires_view=0 # positional encoding of view
self,
feature_vector_size=0,
mode="naive",
d_in=3,
d_out=3,
dims=[512, 512, 512, 256],
weight_norm=True,
multires=0, # positional encoding of points
multires_view=0, # positional encoding of view
):
super().__init__()
self.mode = mode
dims = [d_in + feature_vector_size] + dims
dims = [d_in + feature_vector_size, *dims]
self.embed_fn = None
if multires > 0:
embed_fn, input_ch = get_embedder(multires, d_in=d_in)
self.embed_fn = embed_fn
dims[0] += (input_ch - 3)
dims[0] += input_ch - 3
self.num_layers = len(dims)
self.pts_net = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.num_layers - 1)])
@ -113,13 +114,12 @@ class NeRFRenderingNetwork(nn.Module):
self.embedview_fn = embedview_fn
# dims[0] += (input_ch - 3)
if mode == 'full':
self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)])
if mode == "full":
self.view_net = nn.ModuleList([nn.Linear(dims[-1] + view_ch, 128)])
self.rgb_net = nn.Linear(128, 3)
else:
else:
self.rgb_net = nn.Linear(dims[-1], 3)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
@ -134,7 +134,7 @@ class NeRFRenderingNetwork(nn.Module):
x = net(x)
x = self.relu(x)
if self.mode=='full':
if self.mode == "full":
x = torch.cat([x, view_dirs], -1)
for net in self.view_net:
x = net(x)
@ -144,22 +144,23 @@ class NeRFRenderingNetwork(nn.Module):
x = self.tanh(x)
return x
class ImplicitNetwork(nn.Module):
def __init__(
self,
d_in,
d_out,
dims,
geometric_init=True,
feature_vector_size=0,
bias=1.0,
skip_in=(),
weight_norm=True,
multires=0
self,
d_in,
d_out,
dims,
geometric_init=True,
feature_vector_size=0,
bias=1.0,
skip_in=(),
weight_norm=True,
multires=0,
):
super().__init__()
dims = [d_in] + dims + [d_out + feature_vector_size]
dims = [d_in, *dims, d_out + feature_vector_size]
self.embed_fn = None
if multires > 0:
@ -189,7 +190,7 @@ class ImplicitNetwork(nn.Module):
elif multires > 0 and l in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
else:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
@ -222,13 +223,9 @@ class ImplicitNetwork(nn.Module):
def gradient(self, x):
x.requires_grad_(True)
y = self.forward(x)[:,:1]
y = self.forward(x)[:, :1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
return gradients.unsqueeze(1)
outputs=y, inputs=x, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True,
)[0]
return gradients.unsqueeze(1)

View file

@ -1,55 +1,38 @@
'''
Codes are from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''
"""Codes are from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import numpy as np
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def upconv2x2(in_channels, out_channels, mode='transpose'):
if mode == 'transpose':
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2,
stride=2)
def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups)
def upconv2x2(in_channels, out_channels, mode="transpose"):
if mode == "transpose":
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
else:
# out_channels is always going to be the same
# as in_channels
return nn.Sequential(
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
return nn.Sequential(nn.Upsample(mode="bilinear", scale_factor=2), conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
return nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, stride=1)
class DownConv(nn.Module):
"""
A helper Module that performs 2 convolutions and 1 MaxPool.
"""A helper Module that performs 2 convolutions and 1 MaxPool.
A ReLU activation follows each convolution.
"""
def __init__(self, in_channels, out_channels, pooling=True):
super(DownConv, self).__init__()
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -71,39 +54,35 @@ class DownConv(nn.Module):
class UpConv(nn.Module):
"""
A helper Module that performs 2 convolutions and 1 UpConvolution.
"""A helper Module that performs 2 convolutions and 1 UpConvolution.
A ReLU activation follows each convolution.
"""
def __init__(self, in_channels, out_channels,
merge_mode='concat', up_mode='transpose'):
super(UpConv, self).__init__()
def __init__(self, in_channels, out_channels, merge_mode="concat", up_mode="transpose"):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.merge_mode = merge_mode
self.up_mode = up_mode
self.upconv = upconv2x2(self.in_channels, self.out_channels,
mode=self.up_mode)
self.upconv = upconv2x2(self.in_channels, self.out_channels, mode=self.up_mode)
if self.merge_mode == 'concat':
self.conv1 = conv3x3(
2*self.out_channels, self.out_channels)
if self.merge_mode == "concat":
self.conv1 = conv3x3(2 * self.out_channels, self.out_channels)
else:
# num of input channels to conv2 is same
self.conv1 = conv3x3(self.out_channels, self.out_channels)
self.conv2 = conv3x3(self.out_channels, self.out_channels)
def forward(self, from_down, from_up):
""" Forward pass
"""Forward pass
Arguments:
from_down: tensor from the encoder pathway
from_up: upconv'd tensor from the decoder pathway
from_up: upconv'd tensor from the decoder pathway.
"""
from_up = self.upconv(from_up)
if self.merge_mode == 'concat':
if self.merge_mode == "concat":
x = torch.cat((from_up, from_down), 1)
else:
x = from_up + from_down
@ -113,7 +92,7 @@ class UpConv(nn.Module):
class UNet(nn.Module):
""" `UNet` class is based on https://arxiv.org/abs/1505.04597
"""`UNet` class is based on https://arxiv.org/abs/1505.04597.
The U-Net is a convolutional encoder-decoder neural network.
Contextual spatial information (from the decoding,
@ -135,44 +114,37 @@ class UNet(nn.Module):
the transpose convolution (specified by up_mode='transpose')
"""
def __init__(self, num_classes, in_channels=3, depth=5,
start_filts=64, up_mode='transpose',
merge_mode='concat', **kwargs):
def __init__(
self, num_classes, in_channels=3, depth=5, start_filts=64, up_mode="transpose", merge_mode="concat", **kwargs,
):
"""Arguments:
in_channels: int, number of channels in the input tensor.
Default is 3 for RGB images.
depth: int, number of MaxPools in the U-Net.
start_filts: int, number of convolutional filters for the
first conv.
up_mode: string, type of upconvolution. Choices: 'transpose'
for transpose convolution or 'upsample' for nearest neighbour
upsampling.
"""
Arguments:
in_channels: int, number of channels in the input tensor.
Default is 3 for RGB images.
depth: int, number of MaxPools in the U-Net.
start_filts: int, number of convolutional filters for the
first conv.
up_mode: string, type of upconvolution. Choices: 'transpose'
for transpose convolution or 'upsample' for nearest neighbour
upsampling.
"""
super(UNet, self).__init__()
super().__init__()
if up_mode in ('transpose', 'upsample'):
if up_mode in ("transpose", "upsample"):
self.up_mode = up_mode
else:
raise ValueError("\"{}\" is not a valid mode for "
"upsampling. Only \"transpose\" and "
"\"upsample\" are allowed.".format(up_mode))
if merge_mode in ('concat', 'add'):
msg = f'"{up_mode}" is not a valid mode for upsampling. Only "transpose" and "upsample" are allowed.'
raise ValueError(msg)
if merge_mode in ("concat", "add"):
self.merge_mode = merge_mode
else:
raise ValueError("\"{}\" is not a valid mode for"
"merging up and down paths. "
"Only \"concat\" and "
"\"add\" are allowed.".format(up_mode))
msg = f'"{up_mode}" is not a valid mode formerging up and down paths. Only "concat" and "add" are allowed.'
raise ValueError(msg)
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
if self.up_mode == 'upsample' and self.merge_mode == 'add':
raise ValueError("up_mode \"upsample\" is incompatible "
"with merge_mode \"add\" at the moment "
"because it doesn't make sense to use "
"nearest neighbour to reduce "
"depth channels (by half).")
if self.up_mode == "upsample" and self.merge_mode == "add":
msg = 'up_mode "upsample" is incompatible with merge_mode "add" at the moment because it doesn\'t make sense to use nearest neighbour to reduce depth channels (by half).'
raise ValueError(msg)
self.num_classes = num_classes
self.in_channels = in_channels
@ -185,19 +157,18 @@ class UNet(nn.Module):
# create the encoder pathway and add to a list
for i in range(depth):
ins = self.in_channels if i == 0 else outs
outs = self.start_filts*(2**i)
pooling = True if i < depth-1 else False
outs = self.start_filts * (2**i)
pooling = True if i < depth - 1 else False
down_conv = DownConv(ins, outs, pooling=pooling)
self.down_convs.append(down_conv)
# create the decoder pathway and add to a list
# - careful! decoding only requires depth-1 blocks
for i in range(depth-1):
for i in range(depth - 1):
ins = outs
outs = ins // 2
up_conv = UpConv(ins, outs, up_mode=up_mode,
merge_mode=merge_mode)
up_conv = UpConv(ins, outs, up_mode=up_mode, merge_mode=merge_mode)
self.up_convs.append(up_conv)
# add the list of modules to current module
@ -214,12 +185,10 @@ class UNet(nn.Module):
init.xavier_normal_(m.weight)
init.constant_(m.bias, 0)
def reset_params(self):
for i, m in enumerate(self.modules()):
for _i, m in enumerate(self.modules()):
self.weight_init(m)
def forward(self, x):
encoder_outs = []
# encoder pathway, save outputs for merging
@ -227,30 +196,31 @@ class UNet(nn.Module):
x, before_pool = module(x)
encoder_outs.append(before_pool)
for i, module in enumerate(self.up_convs):
before_pool = encoder_outs[-(i+2)]
before_pool = encoder_outs[-(i + 2)]
x = module(before_pool, x)
# No softmax is used. This means you need to use
# nn.CrossEntropyLoss in your training script,
# as this module includes a softmax already.
x = self.conv_final(x)
return x
if __name__ == "__main__":
"""
testing
"""
model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
model = UNet(1, depth=5, merge_mode="concat", in_channels=1, start_filts=32)
print(model)
print(sum(p.numel() for p in model.parameters()))
reso = 176
x = np.zeros((1, 1, reso, reso))
x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
x = torch.FloatTensor(x)
out = model(x)
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso)))
# loss = torch.sum(out)
# loss.backward()

View file

@ -1,16 +1,18 @@
'''
Code from the 3D UNet implementation:
https://github.com/wolny/pytorch-3dunet/
'''
"""Code from the 3D UNet implementation:
https://github.com/wolny/pytorch-3dunet/.
"""
import importlib
from functools import partial
import torch
import torch.nn as nn
from torch.nn import functional as F
from functools import partial
from src.network.utils import get_embedder
def number_of_features_per_level(init_channel_number, num_levels):
return [init_channel_number * 2 ** k for k in range(num_levels)]
return [init_channel_number * 2**k for k in range(num_levels)]
def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
@ -18,8 +20,7 @@ def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
"""
Create a list of modules which together constitute a single conv layer with non-linearity
"""Create a list of modules which together constitute a single conv layer with non-linearity
and optional batchnorm/groupnorm.
Args:
@ -37,23 +38,23 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
Return:
list of tuple (name, module)
"""
assert 'c' in order, "Conv layer MUST be present"
assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
assert "c" in order, "Conv layer MUST be present"
assert order[0] not in "rle", "Non-linearity cannot be the first operation in the layer"
modules = []
for i, char in enumerate(order):
if char == 'r':
modules.append(('ReLU', nn.ReLU(inplace=True)))
elif char == 'l':
modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
elif char == 'e':
modules.append(('ELU', nn.ELU(inplace=True)))
elif char == 'c':
if char == "r":
modules.append(("ReLU", nn.ReLU(inplace=True)))
elif char == "l":
modules.append(("LeakyReLU", nn.LeakyReLU(negative_slope=0.1, inplace=True)))
elif char == "e":
modules.append(("ELU", nn.ELU(inplace=True)))
elif char == "c":
# add learnable bias only in the absence of batchnorm/groupnorm
bias = not ('g' in order or 'b' in order)
modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
elif char == 'g':
is_before_conv = i < order.index('c')
bias = not ("g" in order or "b" in order)
modules.append(("conv", conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
elif char == "g":
is_before_conv = i < order.index("c")
if is_before_conv:
num_channels = in_channels
else:
@ -63,14 +64,16 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
if num_channels < num_groups:
num_groups = 1
assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
elif char == 'b':
is_before_conv = i < order.index('c')
assert (
num_channels % num_groups == 0
), f"Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}"
modules.append(("groupnorm", nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
elif char == "b":
is_before_conv = i < order.index("c")
if is_before_conv:
modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
modules.append(("batchnorm", nn.BatchNorm3d(in_channels)))
else:
modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
modules.append(("batchnorm", nn.BatchNorm3d(out_channels)))
else:
raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
@ -78,9 +81,8 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
class SingleConv(nn.Sequential):
"""
Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
of operations can be specified via the `order` parameter
"""Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
of operations can be specified via the `order` parameter.
Args:
in_channels (int): number of input channels
@ -94,16 +96,15 @@ class SingleConv(nn.Sequential):
num_groups (int): number of groups for the GroupNorm
"""
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
super(SingleConv, self).__init__()
def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8, padding=1):
super().__init__()
for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
self.add_module(name, module)
class DoubleConv(nn.Sequential):
"""
A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
"""A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
We use (Conv3d+ReLU+GroupNorm3d) by default.
This can be changed however by providing the 'order' argument, e.g. in order
to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
@ -123,8 +124,8 @@ class DoubleConv(nn.Sequential):
num_groups (int): number of groups for the GroupNorm
"""
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
super(DoubleConv, self).__init__()
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order="crg", num_groups=8):
super().__init__()
if encoder:
# we're in the encoder path
conv1_in_channels = in_channels
@ -138,26 +139,27 @@ class DoubleConv(nn.Sequential):
conv2_in_channels, conv2_out_channels = out_channels, out_channels
# conv1
self.add_module('SingleConv1',
SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
self.add_module(
"SingleConv1", SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups),
)
# conv2
self.add_module('SingleConv2',
SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
self.add_module(
"SingleConv2", SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups),
)
class ExtResNetBlock(nn.Module):
"""
Basic UNet block consisting of a SingleConv followed by the residual block.
"""Basic UNet block consisting of a SingleConv followed by the residual block.
The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
of output channels is compatible with the residual block that follows.
This block can be used instead of standard DoubleConv in the Encoder module.
Motivated by: https://arxiv.org/pdf/1706.00120.pdf
Motivated by: https://arxiv.org/pdf/1706.00120.pdf.
Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
"""
def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
super(ExtResNetBlock, self).__init__()
def __init__(self, in_channels, out_channels, kernel_size=3, order="cge", num_groups=8, **kwargs):
super().__init__()
# first convolution
self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
@ -165,15 +167,16 @@ class ExtResNetBlock(nn.Module):
self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
# remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
n_order = order
for c in 'rel':
n_order = n_order.replace(c, '')
self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
num_groups=num_groups)
for c in "rel":
n_order = n_order.replace(c, "")
self.conv3 = SingleConv(
out_channels, out_channels, kernel_size=kernel_size, order=n_order, num_groups=num_groups,
)
# create non-linearity separately
if 'l' in order:
if "l" in order:
self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
elif 'e' in order:
elif "e" in order:
self.non_linearity = nn.ELU(inplace=True)
else:
self.non_linearity = nn.ReLU(inplace=True)
@ -194,12 +197,12 @@ class ExtResNetBlock(nn.Module):
class Encoder(nn.Module):
"""
A single module from the encoder path consisting of the optional max
"""A single module from the encoder path consisting of the optional max
pooling layer (one may specify the MaxPool kernel_size to be different
than the standard (2,2,2), e.g. if the volumetric data is anisotropic
(make sure to use complementary scale_factor in the decoder path) followed by
a DoubleConv module.
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
@ -210,27 +213,39 @@ class Encoder(nn.Module):
basic_module(nn.Module): either ResNetBlock or DoubleConv
conv_layer_order (string): determines the order of layers
in `DoubleConv` module. See `DoubleConv` for more info.
num_groups (int): number of groups for the GroupNorm
num_groups (int): number of groups for the GroupNorm.
"""
def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg',
num_groups=8):
super(Encoder, self).__init__()
assert pool_type in ['max', 'avg']
def __init__(
self,
in_channels,
out_channels,
conv_kernel_size=3,
apply_pooling=True,
pool_kernel_size=(2, 2, 2),
pool_type="max",
basic_module=DoubleConv,
conv_layer_order="crg",
num_groups=8,
):
super().__init__()
assert pool_type in ["max", "avg"]
if apply_pooling:
if pool_type == 'max':
if pool_type == "max":
self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
else:
self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
else:
self.pooling = None
self.basic_module = basic_module(in_channels, out_channels,
encoder=True,
kernel_size=conv_kernel_size,
order=conv_layer_order,
num_groups=num_groups)
self.basic_module = basic_module(
in_channels,
out_channels,
encoder=True,
kernel_size=conv_kernel_size,
order=conv_layer_order,
num_groups=num_groups,
)
def forward(self, x):
if self.pooling is not None:
@ -240,9 +255,9 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
"""
A single module for decoder path consisting of the upsampling layer
"""A single module for decoder path consisting of the upsampling layer
(either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
@ -253,32 +268,56 @@ class Decoder(nn.Module):
basic_module(nn.Module): either ResNetBlock or DoubleConv
conv_layer_order (string): determines the order of layers
in `DoubleConv` module. See `DoubleConv` for more info.
num_groups (int): number of groups for the GroupNorm
num_groups (int): number of groups for the GroupNorm.
"""
def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
conv_layer_order='crg', num_groups=8, mode='nearest'):
super(Decoder, self).__init__()
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
scale_factor=(2, 2, 2),
basic_module=DoubleConv,
conv_layer_order="crg",
num_groups=8,
mode="nearest",
):
super().__init__()
if basic_module == DoubleConv:
# if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
self.upsampling = Upsampling(
transposed_conv=False,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
scale_factor=scale_factor,
mode=mode,
)
# concat joining
self.joining = partial(self._joining, concat=True)
else:
# if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
self.upsampling = Upsampling(
transposed_conv=True,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
scale_factor=scale_factor,
mode=mode,
)
# sum joining
self.joining = partial(self._joining, concat=False)
# adapt the number of in_channels for the ExtResNetBlock
in_channels = out_channels
self.basic_module = basic_module(in_channels, out_channels,
encoder=False,
kernel_size=kernel_size,
order=conv_layer_order,
num_groups=num_groups)
self.basic_module = basic_module(
in_channels,
out_channels,
encoder=False,
kernel_size=kernel_size,
order=conv_layer_order,
num_groups=num_groups,
)
def forward(self, encoder_features, x):
x = self.upsampling(encoder_features=encoder_features, x=x)
@ -295,8 +334,7 @@ class Decoder(nn.Module):
class Upsampling(nn.Module):
"""
Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
"""Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
Args:
transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
@ -310,15 +348,23 @@ class Upsampling(nn.Module):
'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
"""
def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3,
scale_factor=(2, 2, 2), mode='nearest'):
super(Upsampling, self).__init__()
def __init__(
self,
transposed_conv,
in_channels=None,
out_channels=None,
kernel_size=3,
scale_factor=(2, 2, 2),
mode="nearest",
):
super().__init__()
if transposed_conv:
# make sure that the output size reverses the MaxPool3d from the corresponding encoder
# (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0])
self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
padding=1)
# (D_out = (D_in - 1) x stride[0] - 2 x padding[0] + kernel_size[0] + output_padding[0])
self.upsample = nn.ConvTranspose3d(
in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, padding=1,
)
else:
self.upsample = partial(self._interpolate, mode=mode)
@ -332,13 +378,13 @@ class Upsampling(nn.Module):
class FinalConv(nn.Sequential):
"""
A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
"""A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
which reduces the number of channels to 'out_channels'.
We use (Conv3d+ReLU+GroupNorm3d) by default.
This can be changed, however, by providing the 'order' argument, e.g. in order
to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
@ -346,22 +392,22 @@ class FinalConv(nn.Sequential):
order (string): determines the order of layers, e.g.
'cr' -> conv + ReLU
'crg' -> conv + ReLU + groupnorm
num_groups (int): number of groups for the GroupNorm
num_groups (int): number of groups for the GroupNorm.
"""
def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
super(FinalConv, self).__init__()
def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8):
super().__init__()
# conv1
self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
self.add_module("SingleConv", SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
# in the last layer a 1×1 convolution reduces the number of output channels to out_channels
# in the last layer a 1x1 convolution reduces the number of output channels to out_channels
final_conv = nn.Conv3d(in_channels, out_channels, 1)
self.add_module('final_conv', final_conv)
self.add_module("final_conv", final_conv)
class Abstract3DUNet(nn.Module):
"""
Base class for standard and residual UNet.
"""Base class for standard and residual UNet.
Args:
in_channels (int): number of input channels
@ -391,9 +437,22 @@ class Abstract3DUNet(nn.Module):
and the `final_activation` (even if present) won't be applied; default: False
"""
def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs):
super(Abstract3DUNet, self).__init__()
def __init__(
self,
in_channels,
out_channels,
final_sigmoid,
basic_module,
f_maps=64,
layer_order="gcr",
num_groups=8,
num_levels=4,
is_segmentation=False,
testing=False,
pe_freq=0,
**kwargs,
):
super().__init__()
self.testing = testing
@ -411,13 +470,24 @@ class Abstract3DUNet(nn.Module):
encoders = []
for i, out_feature_num in enumerate(f_maps):
if i == 0:
encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module,
conv_layer_order=layer_order, num_groups=num_groups)
encoder = Encoder(
in_channels,
out_feature_num,
apply_pooling=False,
basic_module=basic_module,
conv_layer_order=layer_order,
num_groups=num_groups,
)
else:
# TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
# currently pools with a constant kernel: (2, 2, 2)
encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module,
conv_layer_order=layer_order, num_groups=num_groups)
encoder = Encoder(
f_maps[i - 1],
out_feature_num,
basic_module=basic_module,
conv_layer_order=layer_order,
num_groups=num_groups,
)
encoders.append(encoder)
self.encoders = nn.ModuleList(encoders)
@ -434,13 +504,18 @@ class Abstract3DUNet(nn.Module):
out_feature_num = reversed_f_maps[i + 1]
# TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
# currently strides with a constant stride: (2, 2, 2)
decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module,
conv_layer_order=layer_order, num_groups=num_groups)
decoder = Decoder(
in_feature_num,
out_feature_num,
basic_module=basic_module,
conv_layer_order=layer_order,
num_groups=num_groups,
)
decoders.append(decoder)
self.decoders = nn.ModuleList(decoders)
# in the last layer a 1×1 convolution reduces the number of output
# in the last layer a 1x1 convolution reduces the number of output
# channels to the number of labels
self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
@ -455,7 +530,6 @@ class Abstract3DUNet(nn.Module):
self.final_activation = None
def forward(self, x):
if self.embed_fn is not None:
x = self.embed_fn(x.permute(0, 2, 3, 4, 1))
x = x.permute(0, 4, 1, 2, 3)
@ -488,49 +562,81 @@ class Abstract3DUNet(nn.Module):
class UNet3D(Abstract3DUNet):
"""
3DUnet model from
"""3DUnet model from
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
<https://arxiv.org/pdf/1606.06650.pdf>`.
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
"""
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4, is_segmentation=True, **kwargs):
super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid,
basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order,
num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation,
**kwargs)
def __init__(
self,
in_channels,
out_channels,
final_sigmoid=True,
f_maps=64,
layer_order="gcr",
num_groups=8,
num_levels=4,
is_segmentation=True,
**kwargs,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
basic_module=DoubleConv,
f_maps=f_maps,
layer_order=layer_order,
num_groups=num_groups,
num_levels=num_levels,
is_segmentation=is_segmentation,
**kwargs,
)
class ResidualUNet3D(Abstract3DUNet):
"""
Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
"""Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
Uses ExtResNetBlock as a basic building block, summation joining instead
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
"""
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=5, is_segmentation=True, **kwargs):
super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels,
final_sigmoid=final_sigmoid,
basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order,
num_groups=num_groups, num_levels=num_levels,
is_segmentation=is_segmentation,
**kwargs)
def __init__(
self,
in_channels,
out_channels,
final_sigmoid=True,
f_maps=64,
layer_order="gcr",
num_groups=8,
num_levels=5,
is_segmentation=True,
**kwargs,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
basic_module=ExtResNetBlock,
f_maps=f_maps,
layer_order=layer_order,
num_groups=num_groups,
num_levels=num_levels,
is_segmentation=is_segmentation,
**kwargs,
)
def get_model(config):
def _model_class(class_name):
m = importlib.import_module('pytorch3dunet.unet3d.model')
m = importlib.import_module("pytorch3dunet.unet3d.model")
clazz = getattr(m, class_name)
return clazz
assert 'model' in config, 'Could not find model configuration'
model_config = config['model']
model_class = _model_class(model_config['name'])
assert "model" in config, "Could not find model configuration"
model_config = config["model"]
model_class = _model_class(model_config["name"])
return model_class(**model_config)
@ -542,18 +648,18 @@ if __name__ == "__main__":
out_channels = 1
f_maps = 32
num_levels = 2
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr')
model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order="cr")
print(model)
print('number of parameters: ', sum(p.numel() for p in model.parameters()))
print("number of parameters: ", sum(p.numel() for p in model.parameters()))
reso = 18
import numpy as np
import torch
x = np.zeros((1, 1, reso, reso, reso))
x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan
x[:, :, int(reso / 2 - 1), int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
x = torch.FloatTensor(x)
out = model(x)
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso)))
print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso * reso)))

View file

@ -1,8 +1,9 @@
""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
"""Positional encoding embedding. Code was taken from https://github.com/bmild/nerf."""
import torch
import torch.nn as nn
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
@ -10,24 +11,23 @@ class Embedder:
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
d = self.kwargs["input_dims"]
out_dim = 0
if self.kwargs['include_input']:
if self.kwargs["include_input"]:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
max_freq = self.kwargs["max_freq_log2"]
N_freqs = self.kwargs["num_freqs"]
if self.kwargs['log_sampling']:
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
if self.kwargs["log_sampling"]:
freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn,
freq=freq: p_fn(x * freq))
for p_fn in self.kwargs["periodic_fns"]:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
@ -36,35 +36,40 @@ class Embedder:
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, d_in=3):
embed_kwargs = {
'include_input': True,
'input_dims': d_in,
'max_freq_log2': multires-1,
'num_freqs': multires,
'log_sampling': True,
'periodic_fns': [torch.sin, torch.cos],
"include_input": True,
"input_dims": d_in,
"max_freq_log2": multires - 1,
"num_freqs": multires,
"log_sampling": True,
"periodic_fns": [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
def embed(x, eo=embedder_obj): return eo.embed(x)
def embed(x, eo=embedder_obj):
return eo.embed(x)
return embed, embedder_obj.out_dim
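A short usage sketch for get_embedder (the multires value here is an arbitrary assumption): with include_input=True the embedding concatenates the raw coordinates with sin and cos at each of the multires frequency bands, so out_dim = d_in + 2 * multires * d_in.
embed_fn, out_dim = get_embedder(multires=4, d_in=3)
pts = torch.rand(8, 3)
print(out_dim, embed_fn(pts).shape)  # 27 torch.Size([8, 27])  (3 + 2 * 4 * 3)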
def normalize_coordinate(p, plane='xz'):
''' Normalize coordinate to [0, 1] for unit cube experiments
def normalize_coordinate(p, plane="xz"):
"""Normalize coordinate to [0, 1] for unit cube experiments.
Args:
p (tensor): point
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
plane (str): plane feature type, ['xz', 'xy', 'yz']
'''
if plane == 'xz':
"""
if plane == "xz":
xy = p[:, :, [0, 2]]
elif plane =='xy':
elif plane == "xy":
xy = p[:, :, [0, 1]]
else:
xy = p[:, :, [1, 2]]
xy_new = xy
# if there are outliers out of the range
if xy_new.max() >= 1:
@ -75,40 +80,41 @@ def normalize_coordinate(p, plane='xz'):
def normalize_3d_coordinate(p):
''' Normalize coordinate to [0, 1] for unit cube experiments.
'''
"""Normalize coordinate to [0, 1] for unit cube experiments."""
if p.max() >= 1:
p[p >= 1] = 1 - 10e-6
if p.min() < 0:
p[p < 0] = 0.0
return p
def coordinate2index(x, reso, coord_type='2d'):
''' Normalize coordinate to [0, 1] for unit cube experiments.
Corresponds to our 3D model
def coordinate2index(x, reso, coord_type="2d"):
"""Normalize coordinate to [0, 1] for unit cube experiments.
Corresponds to our 3D model.
Args:
x (tensor): coordinate
reso (int): defined resolution
coord_type (str): coordinate type
'''
"""
x = (x * reso).long()
if coord_type == '2d': # plane
if coord_type == "2d": # plane
index = x[:, :, 0] + reso * x[:, :, 1]
elif coord_type == '3d': # grid
elif coord_type == "3d": # grid
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
index = index[:, None, :]
return index
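To make the flattening above concrete, a small worked sketch (values chosen arbitrarily): a normalized 2D coordinate is first discretized to a cell index and then flattened row-major as x + reso * y.
p = torch.tensor([[[0.25, 0.75]]])  # (batch, n_points, 2), already normalized to [0, 1)
print(coordinate2index(p, reso=4, coord_type="2d"))  # tensor([[13]]): cell (1, 3) -> 1 + 4 * 3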
class map2local(object):
''' Add new keys to the given input
class map2local:
"""Add new keys to the given input.
Args:
s (float): the defined voxel size
pos_encoding (str): method for the positional encoding, linear|sin_cos
'''
def __init__(self, s, pos_encoding='linear'):
"""
def __init__(self, s, pos_encoding="linear"):
super().__init__()
self.s = s
# self.pe = positional_encoding(basis_function=pos_encoding, local=True)
@ -121,15 +127,16 @@ class map2local(object):
# p = self.pe(p)
return p
# Resnet Blocks
class ResnetBlockFC(nn.Module):
''' Fully connected ResNet Block class.
"""Fully connected ResNet Block class.
Args:
size_in (int): input dimension
size_out (int): output dimension
size_h (int): hidden dimension
'''
"""
def __init__(self, size_in, size_out=None, size_h=None, siren=False):
super().__init__()
@ -164,4 +171,4 @@ class ResnetBlockFC(nn.Module):
else:
x_s = x
return x_s + dx
return x_s + dx

View file

@ -1,112 +1,110 @@
import time, os
import os
import numpy as np
import open3d as o3d
import torch
from torch.nn import functional as F
import trimesh
from pytorch3d.loss import chamfer_distance
from torchvision.io import write_video
from torchvision.utils import save_image
from src.dpsr import DPSR
from src.model import PSR2Mesh
from src.utils import grid_interp, verts_on_largest_mesh,\
export_pointcloud, mc_from_psr, GaussianSmoothing
from src.visualize import visualize_points_mesh, visualize_psr_grid, \
visualize_mesh_phong, render_rgb
from torchvision.utils import save_image
from torchvision.io import write_video
from pytorch3d.loss import chamfer_distance
import open3d as o3d
from src.utils import export_pointcloud, mc_from_psr, verts_on_largest_mesh
from src.visualize import (
render_rgb,
visualize_mesh_phong,
visualize_points_mesh,
visualize_psr_grid,
)
class Trainer(object):
'''
Args:
cfg : config file
optimizer : pytorch optimizer object
device : pytorch device
'''
class Trainer:
"""Args:
cfg : config file
optimizer : pytorch optimizer object
device : pytorch device.
"""
def __init__(self, cfg, optimizer, device=None):
self.optimizer = optimizer
self.device = device
self.cfg = cfg
self.psr2mesh = PSR2Mesh.apply
self.data_type = cfg['data']['data_type']
self.data_type = cfg["data"]["data_type"]
# initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
cfg['model']['grid_res'],
cfg['model']['grid_res']),
sig=cfg['model']['psr_sigma'])
if torch.cuda.device_count() > 1:
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR
self.dpsr = DPSR(
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
sig=cfg["model"]["psr_sigma"],
)
if torch.cuda.device_count() > 1:
self.dpsr = torch.nn.DataParallel(self.dpsr)  # parallel DPSR
self.dpsr = self.dpsr.to(device)
def train_step(self, data, inputs, model, it):
''' Performs a training step.
"""Performs a training step.
Args:
data (dict) : data dictionary
inputs (torch.tensor) : input point clouds
model (nn.Module or None): a neural network or None
it (int) : the number of iterations
'''
"""
self.optimizer.zero_grad()
loss, loss_each = self.compute_loss(inputs, data, model, it)
loss.backward()
self.optimizer.step()
return loss.item(), loss_each
def compute_loss(self, inputs, data, model, it=0):
''' Compute the loss.
"""Compute the loss.
Args:
data (dict) : data dictionary
inputs (torch.tensor) : input point clouds
model (nn.Module or None): a neural network or None
it (int) : the number of iterations
'''
it (int) : the number of iterations.
"""
res = self.cfg["model"]["grid_res"]
device = self.device
res = self.cfg['model']['grid_res']
# source oriented point clouds to PSR grid
psr_grid, points, normals = self.pcl2psr(inputs)
# build mesh
v, f, n = self.psr2mesh(psr_grid)
# the output is in the range of [0, 1), we make it to the real range [0, 1].
# This is a hack for our DPSR solver
v = v * res / (res-1)
points = points * 2. - 1.
v = v * 2. - 1. # within the range of (-1, 1)
# the output is in the range of [0, 1); we rescale it to the full range [0, 1].
# This is a hack for our DPSR solver
v = v * res / (res - 1)
points = points * 2.0 - 1.0
v = v * 2.0 - 1.0 # within the range of (-1, 1)
loss = 0
loss_each = {}
# compute loss
if self.data_type == 'point':
if self.cfg['train']['w_chamfer'] > 0:
loss_ = self.cfg['train']['w_chamfer'] * \
self.compute_3d_loss(v, data)
loss_each['chamfer'] = loss_
if self.data_type == "point":
if self.cfg["train"]["w_chamfer"] > 0:
loss_ = self.cfg["train"]["w_chamfer"] * self.compute_3d_loss(v, data)
loss_each["chamfer"] = loss_
loss += loss_
elif self.data_type == 'img':
elif self.data_type == "img":
loss, loss_each = self.compute_2d_loss(inputs, data, model)
return loss, loss_each
def pcl2psr(self, inputs):
''' Convert an oriented point cloud to PSR indicator grid
"""Convert an oriented point cloud to PSR indicator grid
Args:
inputs (torch.tensor): input oriented point clouds
'''
points, normals = inputs[...,:3], inputs[...,3:]
if self.cfg['model']['apply_sigmoid']:
inputs (torch.tensor): input oriented point clouds.
"""
points, normals = inputs[..., :3], inputs[..., 3:]
if self.cfg["model"]["apply_sigmoid"]:
points = torch.sigmoid(points)
if self.cfg['model']['normal_normalize']:
if self.cfg["model"]["normal_normalize"]:
normals = normals / normals.norm(dim=-1, keepdim=True)
# DPSR to get grid
@ -116,53 +114,50 @@ class Trainer(object):
return psr_grid, points, normals
def compute_3d_loss(self, v, data):
''' Compute the loss for point clouds.
"""Compute the loss for point clouds.
Args:
v (torch.tensor) : mesh vertices
data (dict) : data dictionary
'''
pts_gt = data.get('target_points')
idx = np.random.randint(pts_gt.shape[1], size=self.cfg['train']['n_sup_point'])
if self.cfg['train']['subsample_vertex']:
#chamfer distance only on random sampled vertices
idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point'])
data (dict) : data dictionary.
"""
pts_gt = data.get("target_points")
idx = np.random.randint(pts_gt.shape[1], size=self.cfg["train"]["n_sup_point"])
if self.cfg["train"]["subsample_vertex"]:
# chamfer distance only on random sampled vertices
idx = np.random.randint(v.shape[1], size=self.cfg["train"]["n_sup_point"])
loss, _ = chamfer_distance(v[:, idx], pts_gt)
else:
loss, _ = chamfer_distance(v, pts_gt)
return loss
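The chamfer term above comes from pytorch3d; a standalone sketch of the call (the random clouds are placeholders) showing that it accepts batched (B, N, 3) point sets of different sizes and returns the distance plus an optional normal-consistency term:
from pytorch3d.loss import chamfer_distance
pred = torch.rand(1, 2048, 3)  # placeholder predicted vertices
gt = torch.rand(1, 4096, 3)  # placeholder target points
loss, _ = chamfer_distance(pred, gt)
print(loss)  # 0-dim tensor: mean bidirectional chamfer distance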
def compute_2d_loss(self, inputs, data, model):
''' Compute the 2D losses.
"""Compute the 2D losses.
Args:
inputs (torch.tensor) : input source point clouds
data (dict) : data dictionary
model (nn.Module or None): neural network or None
'''
losses = {"color":
{"weight": self.cfg['train']['l_weight']['rgb'],
"values": []
},
"silhouette":
{"weight": self.cfg['train']['l_weight']['mask'],
"values": []},
}
model (nn.Module or None): neural network or None.
"""
losses = {
"color": {
"weight": self.cfg["train"]["l_weight"]["rgb"],
"values": [],
},
"silhouette": {"weight": self.cfg["train"]["l_weight"]["mask"], "values": []},
}
loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses}
# forward pass
out = model(inputs, data)
if out['rgb'] is not None:
rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'],
-1, 3)[out['vis_mask']]
loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt,
out['rgb']) / out['rgb'].shape[0]
if out["rgb"] is not None:
rgb_gt = out["rgb_gt"].reshape(self.cfg["data"]["n_views_per_iter"], -1, 3)[out["vis_mask"]]
loss_all["color"] += torch.nn.L1Loss(reduction="sum")(rgb_gt, out["rgb"]) / out["rgb"].shape[0]
if out["mask"] is not None:
loss_all["silhouette"] += ((out["mask"] - out["mask_gt"]) ** 2).mean()
if out['mask'] is not None:
loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean()
# weighted sum of the losses
loss = torch.tensor(0.0, device=self.device)
for k, l in loss_all.items():
@ -172,154 +167,146 @@ class Trainer(object):
return loss, loss_all
def point_resampling(self, inputs):
''' Resample points
"""Resample points
Args:
inputs (torch.tensor): oriented point clouds
'''
inputs (torch.tensor): oriented point clouds.
"""
psr_grid, points, normals = self.pcl2psr(inputs)
# shortcuts
n_grow = self.cfg['train']['n_grow_points']
# [hack] for points resampled from the mesh from marching cubes,
# shortcuts
n_grow = self.cfg["train"]["n_grow_points"]
# [hack] for points resampled from the mesh from marching cubes,
# we need to divide by s instead of (s-1), and the scale is correct.
verts, faces, _ = mc_from_psr(psr_grid, real_scale=False, zero_level=0)
# find the largest component
pts_mesh, faces_mesh = verts_on_largest_mesh(verts, faces)
# sample vertices only from the largest component, not from fragments
mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh)
pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True)
normals_i = mesh.face_normals[face_idx].astype('float32')
pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None]
pi, face_idx = mesh.sample(n_grow + points.shape[1], return_index=True)
normals_i = mesh.face_normals[face_idx].astype("float32")
pts_mesh = torch.tensor(pi.astype("float32")).to(self.device)[None]
n_mesh = torch.tensor(normals_i).to(self.device)[None]
points, normals = pts_mesh, n_mesh
print('{} total points are resampled'.format(points.shape[1]))
print(f"{points.shape[1]} total points are resampled")
# update inputs
points = torch.log(points / (1 - points)) # inverse sigmoid
points = torch.log(points / (1 - points)) # inverse sigmoid
inputs = torch.cat([points, normals], dim=-1)
inputs.requires_grad = True
inputs.requires_grad = True
return inputs
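The torch.log(points / (1 - points)) line above is the logit (inverse sigmoid): pcl2psr applies torch.sigmoid to keep points inside the unit cube, so the resampled points are mapped back to the unconstrained domain before being made a leaf tensor again. A minimal numerical check:
p = torch.tensor([0.1, 0.5, 0.9])
logits = torch.log(p / (1 - p))  # inverse sigmoid
assert torch.allclose(torch.sigmoid(logits), p)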
def visualize(self, data, inputs, renderer, epoch, o3d_vis=None):
''' Visualization.
"""Visualization.
Args:
data (dict) : data dictionary
inputs (torch.tensor) : source point clouds
renderer (nn.Module or None): a neural network or None
epoch (int) : the number of iterations
o3d_vis (o3d.Visualizer) : open3d visualizer
'''
data_type = self.cfg['data']['data_type']
it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every']))
o3d_vis (o3d.Visualizer) : open3d visualizer.
"""
data_type = self.cfg["data"]["data_type"]
it = "{:04d}".format(int(epoch / self.cfg["train"]["visualize_every"]))
if (self.cfg['train']['exp_mesh']) \
| (self.cfg['train']['exp_pcl']) \
| (self.cfg['train']['o3d_show']):
if (self.cfg["train"]["exp_mesh"]) | (self.cfg["train"]["exp_pcl"]) | (self.cfg["train"]["o3d_show"]):
psr_grid, points, normals = self.pcl2psr(inputs)
with torch.no_grad():
v, f, n = mc_from_psr(psr_grid, pytorchify=True,
zero_level=self.cfg['data']['zero_level'], real_scale=True)
v, f, n = mc_from_psr(
psr_grid, pytorchify=True, zero_level=self.cfg["data"]["zero_level"], real_scale=True,
)
v, f, n = v[None], f[None], n[None]
v = v * 2. - 1. # change to the range of [-1, 1]
v = v * 2.0 - 1.0 # change to the range of [-1, 1]
color_v = None
if data_type == 'img':
if self.cfg['train']['vis_vert_color'] & \
(self.cfg['train']['l_weight']['rgb'] != 0.):
color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy()
color_v[color_v<0], color_v[color_v>1] = 0., 1.
if data_type == "img":
if self.cfg["train"]["vis_vert_color"] & (self.cfg["train"]["l_weight"]["rgb"] != 0.0):
color_v = renderer["color"](v, n).squeeze().detach().cpu().numpy()
color_v[color_v < 0], color_v[color_v > 1] = 0.0, 1.0
vv = v.detach().squeeze().cpu().numpy()
ff = f.detach().squeeze().cpu().numpy()
points = points * 2 - 1
visualize_points_mesh(o3d_vis, points, normals,
vv, ff, self.cfg, it, epoch, color_v=color_v)
visualize_points_mesh(o3d_vis, points, normals, vv, ff, self.cfg, it, epoch, color_v=color_v)
else:
v, f, n = inputs
if (data_type == 'img') & (self.cfg['train']['vis_rendering']):
if (data_type == "img") & (self.cfg["train"]["vis_rendering"]):
pred_imgs = []
pred_masks = []
n_views = len(data['poses'])
len(data["poses"])
# idx_list = trange(n_views)
idx_list = [13, 24, 27, 48]
#!
#!
model = renderer.eval()
for idx in idx_list:
pose = data['poses'][idx]
rgb = data['rgbs'][idx]
mask_gt = data['masks'][idx]
img_size = rgb.shape[0] if rgb.shape[0]== rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
pose = data["poses"][idx]
rgb = data["rgbs"][idx]
data["masks"][idx]
img_size = rgb.shape[0] if rgb.shape[0] == rgb.shape[1] else (rgb.shape[0], rgb.shape[1])
ray = None
if 'rays' in data.keys():
ray = data['rays'][idx]
if self.cfg['train']['l_weight']['rgb'] != 0.:
if "rays" in data.keys():
ray = data["rays"][idx]
if self.cfg["train"]["l_weight"]["rgb"] != 0.0:
fea_grid = None
if model.unet3d is not None:
with torch.no_grad():
fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1)
if model.encoder is not None:
pp = torch.cat([(points+1)/2, normals], dim=-1)
fea_grid = model.encoder(pp,
normalize=False).permute(0, 2, 3, 4, 1)
pp = torch.cat([(points + 1) / 2, normals], dim=-1)
fea_grid = model.encoder(pp, normalize=False).permute(0, 2, 3, 4, 1)
pred, visible_mask = render_rgb(v, f, n, pose,
model.rendering_network.eval(),
img_size, ray=ray, fea_grid=fea_grid)
img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3])
pred, visible_mask = render_rgb(
v, f, n, pose, model.rendering_network.eval(), img_size, ray=ray, fea_grid=fea_grid,
)
img_pred = torch.zeros([rgb.shape[0] * rgb.shape[1], 3])
img_pred[visible_mask] = pred.detach().cpu()
img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3)
img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1.
filename=os.path.join(self.cfg['train']['dir_rendering'],
'rendering_{}_{:d}.png'.format(it, idx))
img_pred[img_pred < 0], img_pred[img_pred > 1] = 0.0, 1.0
filename = os.path.join(self.cfg["train"]["dir_rendering"], f"rendering_{it}_{idx:d}.png")
save_image(img_pred.permute(2, 0, 1), filename)
pred_imgs.append(img_pred)
#! Mesh rendering using Phong shading model
filename=os.path.join(self.cfg['train']['dir_rendering'],
'mesh_{}_{:d}.png'.format(it, idx))
filename = os.path.join(self.cfg["train"]["dir_rendering"], f"mesh_{it}_{idx:d}.png")
visualize_mesh_phong(v, f, n, pose, img_size, name=filename)
if len(pred_imgs) >= 1:
pred_imgs = torch.stack(pred_imgs, dim=0)
save_image(pred_imgs.permute(0, 3, 1, 2),
os.path.join(self.cfg['train']['dir_rendering'],
'{}.png'.format(it)), nrow=4)
if self.cfg['train']['save_video']:
write_video(os.path.join(self.cfg['train']['dir_rendering'],
'{}.mp4'.format(it)),
(pred_imgs*255.).type(torch.uint8), fps=24)
save_image(
pred_imgs.permute(0, 3, 1, 2), os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.png"), nrow=4,
)
if self.cfg["train"]["save_video"]:
write_video(
os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.mp4"),
(pred_imgs * 255.0).type(torch.uint8),
fps=24,
)
def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None):
''' Save meshes and point clouds.
"""Save meshes and point clouds.
Args:
inputs (torch.tensor) : source point clouds
epoch (int) : the number of iterations
center (numpy.array) : center of the shape
scale (numpy.array) : scale of the shape
'''
scale (numpy.array) : scale of the shape.
"""
exp_pcl = self.cfg["train"]["exp_pcl"]
exp_mesh = self.cfg["train"]["exp_mesh"]
exp_pcl = self.cfg['train']['exp_pcl']
exp_mesh = self.cfg['train']['exp_mesh']
psr_grid, points, normals = self.pcl2psr(inputs)
if exp_pcl:
dir_pcl = self.cfg['train']['dir_pcl']
dir_pcl = self.cfg["train"]["dir_pcl"]
p = points.squeeze(0).detach().cpu().numpy()
p = p * 2 - 1
n = normals.squeeze(0).detach().cpu().numpy()
@ -327,12 +314,11 @@ class Trainer(object):
p *= scale
if center is not None:
p += center
export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n)
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}.ply"), p, n)
if exp_mesh:
dir_mesh = self.cfg['train']['dir_mesh']
dir_mesh = self.cfg["train"]["dir_mesh"]
with torch.no_grad():
v, f, _ = mc_from_psr(psr_grid,
zero_level=self.cfg['data']['zero_level'], real_scale=True)
v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"], real_scale=True)
v = v * 2 - 1
if scale is not None:
v *= scale
@ -341,9 +327,9 @@ class Trainer(object):
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(v)
mesh.triangles = o3d.utility.Vector3iVector(f)
outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch))
outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}.ply")
o3d.io.write_triangle_mesh(outdir_mesh, mesh)
if self.cfg['train']['vis_psr']:
dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all'
if self.cfg["train"]["vis_psr"]:
dir_psr_vis = self.cfg["train"]["out_dir"] + "/psr_vis_all"
visualize_psr_grid(psr_grid, out_dir=dir_psr_vis)

View file

@ -1,75 +1,81 @@
import os
from collections import defaultdict
import numpy as np
import torch
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops.knn import knn_gather, knn_points
from torch.nn import functional as F
from collections import defaultdict
import trimesh
from tqdm import tqdm
from src.dpsr import DPSR
from src.utils import grid_interp, export_pointcloud, export_mesh, \
mc_from_psr, scale2onet, GaussianSmoothing
from pytorch3d.ops.knn import knn_gather, knn_points
from pytorch3d.loss import chamfer_distance
from pdb import set_trace as st
from src.utils import (
GaussianSmoothing,
export_mesh,
export_pointcloud,
mc_from_psr,
scale2onet,
)
class Trainer:
"""Args:
model (nn.Module): our defined model
optimizer (optimizer): pytorch optimizer object
device (device): pytorch device
input_type (str): input type
vis_dir (str): visualization directory.
"""
class Trainer(object):
'''
Args:
model (nn.Module): our defined model
optimizer (optimizer): pytorch optimizer object
device (device): pytorch device
input_type (str): input type
vis_dir (str): visualization directory
'''
def __init__(self, cfg, optimizer, device=None):
self.optimizer = optimizer
self.device = device
self.cfg = cfg
if self.cfg['train']['w_raw'] != 0:
if self.cfg["train"]["w_raw"] != 0:
from src.model import PSR2Mesh
self.psr2mesh = PSR2Mesh.apply
# initialize DPSR
self.dpsr = DPSR(res=(cfg['model']['grid_res'],
cfg['model']['grid_res'],
cfg['model']['grid_res']),
sig=cfg['model']['psr_sigma'])
if torch.cuda.device_count() > 1:
self.dpsr = torch.nn.DataParallel(self.dpsr) # parallel DPSR
self.dpsr = DPSR(
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
sig=cfg["model"]["psr_sigma"],
)
if torch.cuda.device_count() > 1:
self.dpsr = torch.nn.DataParallel(self.dpsr)  # parallel DPSR
self.dpsr = self.dpsr.to(device)
if cfg['train']['gauss_weight']>0.:
if cfg["train"]["gauss_weight"] > 0.0:
self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device)
def train_step(self, inputs, data, model):
''' Performs a training step.
"""Performs a training step.
Args:
data (dict): data dictionary
'''
"""
self.optimizer.zero_grad()
p = data.get('inputs').to(self.device)
p = data.get("inputs").to(self.device)
out = model(p)
points, normals = out
loss = 0
loss_each = {}
if self.cfg['train']['w_psr'] != 0:
psr_gt = data.get('gt_psr').to(self.device)
if self.cfg['model']['psr_tanh']:
if self.cfg["train"]["w_psr"] != 0:
psr_gt = data.get("gt_psr").to(self.device)
if self.cfg["model"]["psr_tanh"]:
psr_gt = torch.tanh(psr_gt)
psr_grid = self.dpsr(points, normals)
if self.cfg['model']['psr_tanh']:
if self.cfg["model"]["psr_tanh"]:
psr_grid = torch.tanh(psr_grid)
# apply a rescaling weight based on GT SDF values
if self.cfg['train']['gauss_weight']>0:
gauss_sigma = self.cfg['train']['gauss_weight']
# set up the weighting for loss, higher weights
if self.cfg["train"]["gauss_weight"] > 0:
self.cfg["train"]["gauss_weight"]
# set up the weighting for loss, higher weights
# for points near to the surface
psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1)
delta_x = delta_y = delta_z = 1
@ -82,97 +88,97 @@ class Trainer(object):
psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1)
psr_grad_norm = psr_grad.norm(dim=-1)[:, None]
w = torch.nn.ReplicationPad3d(3)(psr_grad_norm)
w = 2*self.gauss_smooth(w).squeeze(1)
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt)
w = 2 * self.gauss_smooth(w).squeeze(1)
loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(w * psr_grid, w * psr_gt)
else:
loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt)
loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(psr_grid, psr_gt)
loss += loss_each['psr']
loss += loss_each["psr"]
# regularization on the input point positions via chamfer distance
if self.cfg['train']['w_reg_point'] != 0.:
points_gt = data.get('gt_points').to(self.device)
if self.cfg["train"]["w_reg_point"] != 0.0:
points_gt = data.get("gt_points").to(self.device)
loss_reg, loss_norm = chamfer_distance(points, points_gt)
loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg
loss += loss_each['reg']
if self.cfg['train']['w_normals'] != 0.:
points_gt = data.get('gt_points').to(self.device)
normals_gt = data.get('gt_points.normals').to(self.device)
loss_each["reg"] = self.cfg["train"]["w_reg_point"] * loss_reg
loss += loss_each["reg"]
if self.cfg["train"]["w_normals"] != 0.0:
points_gt = data.get("gt_points").to(self.device)
normals_gt = data.get("gt_points.normals").to(self.device)
x_nn = knn_points(points, points_gt, K=1)
x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :]
cham_norm_x = F.l1_loss(normals, x_normals_near)
loss_norm = cham_norm_x
loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm
loss += loss_each['normals']
if self.cfg['train']['w_raw'] != 0:
res = self.cfg['model']['grid_res']
loss_each["normals"] = self.cfg["train"]["w_normals"] * loss_norm
loss += loss_each["normals"]
if self.cfg["train"]["w_raw"] != 0:
self.cfg["model"]["grid_res"]
# DPSR to get grid
psr_grid = self.dpsr(points, normals)
if self.cfg['model']['psr_tanh']:
if self.cfg["model"]["psr_tanh"]:
psr_grid = torch.tanh(psr_grid)
v, f, n = self.psr2mesh(psr_grid)
pts_gt = data.get('gt_points').to(self.device)
pts_gt = data.get("gt_points").to(self.device)
loss, _ = chamfer_distance(v, pts_gt)
loss.backward()
self.optimizer.step()
return loss.item(), loss_each
def save(self, model, data, epoch, id):
p = data.get('inputs').to(self.device)
exp_pcl = self.cfg['train']['exp_pcl']
exp_mesh = self.cfg['train']['exp_mesh']
exp_gt = self.cfg['generation']['exp_gt']
exp_input = self.cfg['generation']['exp_input']
def save(self, model, data, epoch, id):
p = data.get("inputs").to(self.device)
exp_pcl = self.cfg["train"]["exp_pcl"]
exp_mesh = self.cfg["train"]["exp_mesh"]
exp_gt = self.cfg["generation"]["exp_gt"]
exp_input = self.cfg["generation"]["exp_input"]
model.eval()
with torch.no_grad():
points, normals = model(p)
if exp_gt:
points_gt = data.get('gt_points').to(self.device)
normals_gt = data.get('gt_points.normals').to(self.device)
points_gt = data.get("gt_points").to(self.device)
normals_gt = data.get("gt_points.normals").to(self.device)
if exp_pcl:
dir_pcl = self.cfg['train']['dir_pcl']
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals)
dir_pcl = self.cfg["train"]["dir_pcl"]
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}.ply"), scale2onet(points), normals)
if exp_gt:
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt)
export_pointcloud(
os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(points_gt), normals_gt,
)
if exp_input:
export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p))
export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_input.ply"), scale2onet(p))
if exp_mesh:
dir_mesh = self.cfg['train']['dir_mesh']
dir_mesh = self.cfg["train"]["dir_mesh"]
psr_grid = self.dpsr(points, normals)
# psr_grid = torch.tanh(psr_grid)
with torch.no_grad():
v, f, _ = mc_from_psr(psr_grid,
zero_level=self.cfg['data']['zero_level'])
outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id))
v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"])
outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}.ply")
export_mesh(outdir_mesh, scale2onet(v), f)
if exp_gt:
psr_gt = self.dpsr(points_gt, normals_gt)
with torch.no_grad():
v, f, _ = mc_from_psr(psr_gt,
zero_level=self.cfg['data']['zero_level'])
export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f)
v, f, _ = mc_from_psr(psr_gt, zero_level=self.cfg["data"]["zero_level"])
export_mesh(os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(v), f)
def evaluate(self, val_loader, model):
''' Performs an evaluation.
"""Performs an evaluation.
Args:
val_loader (dataloader): pytorch dataloader
'''
val_loader (dataloader): pytorch dataloader.
"""
eval_list = defaultdict(list)
for data in tqdm(val_loader):
@ -183,25 +189,26 @@ class Trainer(object):
eval_dict = {k: np.mean(v) for k, v in eval_list.items()}
return eval_dict
def eval_step(self, data, model):
''' Performs an evaluation step.
"""Performs an evaluation step.
Args:
data (dict): data dictionary
'''
data (dict): data dictionary.
"""
model.eval()
eval_dict = {}
p = data.get('inputs').to(self.device)
psr_gt = data.get('gt_psr').to(self.device)
p = data.get("inputs").to(self.device)
psr_gt = data.get("gt_psr").to(self.device)
with torch.no_grad():
# forward pass
points, normals = model(p)
# DPSR to get predicted psr grid
psr_grid = self.dpsr(points, normals)
eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item()
eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item()
eval_dict["psr_l1"] = F.l1_loss(psr_grid, psr_gt).item()
eval_dict["psr_l2"] = F.mse_loss(psr_grid, psr_gt).item()
return eval_dict
return eval_dict

View file

@ -1,53 +1,53 @@
import torch
import io, os, logging, urllib
import yaml
import trimesh
import imageio
import numbers
import logging
import math
import numpy as np
import numbers
import os
import urllib
from collections import OrderedDict
import numpy as np
import open3d as o3d
import torch
import trimesh
import yaml
from igl import adjacency_matrix, connected_components
from plyfile import PlyData
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
from pytorch3d.structures import Meshes
from skimage import measure
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
from skimage import measure, img_as_float32
from pytorch3d.structures import Meshes
from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes
from igl import adjacency_matrix, connected_components
import open3d as o3d
##################################################
# Below are functions for DPSR
def fftfreqs(res, dtype=torch.float32, exact=True):
"""
Helper function to return frequency tensors
"""Helper function to return frequency tensors
:param res: n_dims int tuple of number of frequency modes
:return:
"""
n_dims = len(res)
freqs = []
for dim in range(n_dims - 1):
r_ = res[dim]
freq = np.fft.fftfreq(r_, d=1/r_)
freq = np.fft.fftfreq(r_, d=1 / r_)
freqs.append(torch.tensor(freq, dtype=dtype))
r_ = res[-1]
if exact:
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_), dtype=dtype))
else:
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_)[:-1], dtype=dtype))
omega = torch.meshgrid(freqs)
omega = list(omega)
omega = torch.stack(omega, dim=-1)
return omega
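# Usage sketch (illustrative; assumes this module is importable as src.utils, as elsewhere in the repo):
# build the frequency grid for a 32^3 DPSR volume. With exact=True the last axis uses rfftfreq,
# so the result has shape (32, 32, 17, 3).
import torch
from src.utils import fftfreqs
omega = fftfreqs((32, 32, 32), dtype=torch.float32, exact=True)
print(omega.shape)  # torch.Size([32, 32, 17, 3])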
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
"""
multiply tensor x by i ** deg
"""
def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
"""Multiply tensor x by i ** deg."""
deg %= 4
if deg == 0:
res = x
@@ -61,17 +61,18 @@ def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
res[..., 1] = -res[..., 1]
return res
def spec_gaussian_filter(res, sig):
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
dis = torch.sqrt(torch.sum(omega**2, dim=-1))
filter_ = torch.exp(-0.5 * ((sig * 2 * dis / res[0]) ** 2)).unsqueeze(-1).unsqueeze(-1)
filter_.requires_grad = False
return filter_
def grid_interp(grid, pts, batched=True):
"""
:param grid: tensor of shape (batch, *size, in_features)
""":param grid: tensor of shape (batch, *size, in_features)
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
:return values at query points
"""
@@ -82,69 +83,71 @@ def grid_interp(grid, pts, batched=True):
bs = grid.shape[0]
size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
cubesize = 1.0 / size
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
tmp = torch.tensor([0,1],dtype=torch.long)
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
tmp = torch.tensor([0, 1], dtype=torch.long)
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
# latent code on neighbor nodes
if dim == 2:
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
else:
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features)
lat = grid.clone()[
ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2],
] # (batch, num_points, 2**dim, in_features)
# weights of neighboring nodes
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
pos_ = pos_.type(pts.dtype)
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features)
if not batched:
query_values = query_values.squeeze(0)
return query_values
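# Usage sketch (illustrative; import path assumed): trilinearly interpolate a feature grid
# at normalized query points in [0, 1).
import torch
from src.utils import grid_interp
grid = torch.rand(2, 8, 8, 8, 4)   # (batch, *size, in_features)
pts = torch.rand(2, 100, 3)        # (batch, num_points, dim), coordinates in [0, 1)
feat = grid_interp(grid, pts)      # (2, 100, 4)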
def scatter_to_grid(inds, vals, size):
"""
Scatter update values into empty tensor of size size.
"""Scatter update values into empty tensor of size size.
:param inds: (#values, dims)
:param vals: (#values)
:param size: tuple for size. len(size)=dims
:param size: tuple for size. len(size)=dims.
"""
dims = inds.shape[1]
assert(inds.shape[0] == vals.shape[0])
assert(len(size) == dims)
assert inds.shape[0] == vals.shape[0]
assert len(size) == dims
dev = vals.device
# result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
# # flatten inds
result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
# flatten inds
fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
fac = [np.prod(size[i + 1 :]) for i in range(len(size) - 1)] + [1]
fac = torch.tensor(fac, device=dev).type(inds.dtype)
inds_fold = torch.sum(inds*fac, dim=-1) # [#values,]
inds_fold = torch.sum(inds * fac, dim=-1) # [#values,]
result.scatter_add_(0, inds_fold, vals)
result = result.view(*size)
return result
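# Tiny sketch (illustrative): scatter-add three values into an empty 4x4 grid; repeated
# indices accumulate.
import torch
from src.utils import scatter_to_grid
inds = torch.tensor([[0, 1], [0, 1], [3, 2]])   # (#values, dims)
vals = torch.tensor([1.0, 2.0, 0.5])
grid = scatter_to_grid(inds, vals, (4, 4))      # grid[0, 1] == 3.0, grid[3, 2] == 0.5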
def point_rasterize(pts, vals, size):
"""
:param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
""":param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
:param vals: point values, tensor of shape (batch, num_points, features)
:param size: len(size)=dim tuple for grid size
:return rasterized values (batch, features, res0, res1, res2)
"""
dim = pts.shape[-1]
assert(pts.shape[:2] == vals.shape[:2])
assert(pts.shape[2] == dim)
assert pts.shape[:2] == vals.shape[:2]
assert pts.shape[2] == dim
size_list = list(size)
size = torch.tensor(size).to(pts.device).float()
cubesize = 1.0 / size
@@ -152,55 +155,58 @@ def point_rasterize(pts, vals, size):
nf = vals.shape[-1]
npts = pts.shape[1]
dev = pts.device
ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
tmp = torch.tensor([0,1],dtype=torch.long)
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
tmp = torch.tensor([0, 1], dtype=torch.long)
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
# ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
ind_b = (
torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)
) # (batch, num_points, 2**dim)
# weights of neighboring nodes
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
pos_ = pos_.type(pts.dtype)
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1)
ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim)
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1)
ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim)
ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
# ind_f = torch.arange(nf).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
inds = torch.cat([ind_b, ind_f, ind_n], dim=-1) # (batch, num_points, 2**dim, nf, 1+1+dim)
# weighted values
vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
tensor_size = [bs, nf] + size_list
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list)
return raster # [batch, nf, res, res, res]
# weighted values
vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
inds = inds.view(-1, dim + 2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
[bs, nf, *size_list]
raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf, *size_list])
return raster # [batch, nf, res, res, res]
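# Usage sketch (illustrative; import path assumed): splat per-point normals onto a 32^3 grid
# with trilinear weights, i.e. the rasterization step that feeds the spectral DPSR solve.
import torch
from src.utils import point_rasterize
pts = torch.rand(1, 2048, 3)        # (batch, num_points, 3), coordinates in [0, 1)
normals = torch.randn(1, 2048, 3)   # (batch, num_points, nf)
grid = point_rasterize(pts, normals, (32, 32, 32))   # (1, 3, 32, 32, 32)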
##################################################
# Below are general utility functions
class AverageMeter(object):
"""Computes and stores the average and current value"""
class AverageMeter:
"""Computes and stores the average and current value."""
def __init__(self):
self.reset()
@@ -226,49 +232,49 @@ class AverageMeter(object):
def avgcavg(self):
return self.avg.sum().item() / (self.count != 0).sum().item()
def load_model_manual(state_dict, model):
new_state_dict = OrderedDict()
is_model_parallel = isinstance(model, torch.nn.DataParallel)
for k, v in state_dict.items():
if k.startswith('module.') != is_model_parallel:
if k.startswith('module.'):
if k.startswith("module.") != is_model_parallel:
if k.startswith("module."):
# remove module
k = k[7:]
else:
# add module
k = 'module.' + k
k = "module." + k
new_state_dict[k]=v
new_state_dict[k] = v
model.load_state_dict(new_state_dict)
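# Usage sketch (paths and config are examples only): load a checkpoint no matter whether it
# was saved from a DataParallel wrapper; load_model_manual adds or strips the "module." prefix
# so the keys match the target model.
import torch
from src.model import Encode2Points
from src.utils import load_config, load_model_manual
cfg = load_config("configs/default.yaml")
model = Encode2Points(cfg)
state = torch.load("out/model_best.pt", map_location="cpu")
load_model_manual(state["state_dict"], model)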
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
'''
Run marching cubes from PSR grid
'''
"""Run marching cubes from PSR grid."""
batch_size = psr_grid.shape[0]
s = psr_grid.shape[-1] # size of psr_grid
s = psr_grid.shape[-1] # size of psr_grid
psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
if batch_size>1:
if batch_size > 1:
verts, faces, normals = [], [], []
for i in range(batch_size):
verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
verts.append(verts_cur)
faces.append(faces_cur)
normals.append(normals_cur)
verts = np.stack(verts, axis = 0)
faces = np.stack(faces, axis = 0)
normals = np.stack(normals, axis = 0)
verts = np.stack(verts, axis=0)
faces = np.stack(faces, axis=0)
normals = np.stack(normals, axis=0)
else:
try:
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
except:
verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
if real_scale:
verts = verts / (s-1) # scale to range [0, 1]
verts = verts / (s - 1) # scale to range [0, 1]
else:
verts = verts / s # scale to range [0, 1)
verts = verts / s # scale to range [0, 1)
if pytorchify:
device = psr_grid.device
@@ -277,7 +283,8 @@ def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)
return verts, faces, normals
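# Usage sketch (toy grid, just to show the call): extract the zero level set of a PSR grid
# with marching cubes; pytorchify=True returns tensors on the input's device.
import torch
from src.utils import mc_from_psr
psr_grid = torch.randn(1, 64, 64, 64)   # stand-in indicator grid
v, f, n = mc_from_psr(psr_grid, pytorchify=True, zero_level=0)
# v lies in [0, 1) (or [0, 1] with real_scale=True); f indexes triangles; n are the flipped normals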
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
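# Intersect camera rays with the mesh: for every rasterized pixel this returns the 3D
# intersection point (p_inters), the valid-pixel mask, the hit face indices (f_p) and the
# barycentric weights (w_masked).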
verts = verts.squeeze()
faces = faces.squeeze()
@@ -288,42 +295,39 @@ def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
# find 3D points intersected on the mesh
if True:
w_masked = w[mask]
f_p = faces[pix_to_face[mask]].long() # cooresponding faces for each pixel
f_p = faces[pix_to_face[mask]].long()  # corresponding faces for each pixel
# corresponding vertices for p_closest
v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
# calculate the intersection point of each pixel and the mesh
p_inters = w_masked[..., 0, None] * v_a + \
w_masked[..., 1, None] * v_b + \
w_masked[..., 2, None] * v_c
p_inters = w_masked[..., 0, None] * v_a + w_masked[..., 1, None] * v_b + w_masked[..., 2, None] * v_c
else:
# backproject ndc to world coordinates using z-buffer
W, H = img_size[1], img_size[0]
xy = uv.to(mask.device)[mask]
x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
x_ndc = 1 - (2 * xy[:, 0]) / (W - 1)
y_ndc = 1 - (2 * xy[:, 1]) / (H - 1)
z = zbuf.squeeze().reshape(H * W)[mask]
xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
# if there are outlier points, we should remove it
if (p_inters.max()>1) | (p_inters.min()<-1):
mask_bound = (p_inters>=-1) & (p_inters<=1)
mask_bound = (mask_bound.sum(dim=-1)==3)
mask[mask==True] = mask_bound
if (p_inters.max() > 1) | (p_inters.min() < -1):
mask_bound = (p_inters >= -1) & (p_inters <= 1)
mask_bound = mask_bound.sum(dim=-1) == 3
mask[mask] = mask_bound  # boolean-tensor indexing (equivalent to mask == True)
p_inters = p_inters[mask_bound]
print('!!!!!find outlier!')
print("!!!!!find outlier!")
return p_inters, mask, f_p, w_masked
def mesh_rasterization(verts, faces, pose, img_size):
'''
Use PyTorch3D to rasterize the mesh given a camera
'''
transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
"""Use PyTorch3D to rasterize the mesh given a camera."""
transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
if isinstance(pose, PerspectiveCameras):
transformed_v[..., 2] = 1/transformed_v[..., 2]
transformed_v[..., 2] = 1 / transformed_v[..., 2]
# find p_closest on mesh of each pixel via rasterization
transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
@@ -331,9 +335,9 @@ def mesh_rasterization(verts, faces, pose, img_size):
image_size=img_size,
blur_radius=0,
faces_per_pixel=1,
perspective_correct=False
perspective_correct=False,
)
pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
mask = pix_to_face.clone() != -1
mask = mask.squeeze()
pix_to_face = pix_to_face.squeeze()
@@ -341,11 +345,11 @@ def mesh_rasterization(verts, faces, pose, img_size):
return pix_to_face, w, mask
def verts_on_largest_mesh(verts, faces):
'''
verts: Numpy array or Torch.Tensor (N, 3)
faces: Numpy array (N, 3)
'''
"""verts: Numpy array or Torch.Tensor (N, 3)
faces: Numpy array (N, 3).
"""
if torch.is_tensor(faces):
verts = verts.squeeze().detach().cpu().numpy()
faces = faces.squeeze().int().detach().cpu().numpy()
@@ -355,8 +359,8 @@ def verts_on_largest_mesh(verts, faces):
if num == 0:
v_large, f_large = verts, faces
else:
max_idx = conn_size.argmax() # find the index of the largest component
v_large = verts[conn_idx==max_idx] # keep points on the largest component
max_idx = conn_size.argmax() # find the index of the largest component
v_large = verts[conn_idx == max_idx] # keep points on the largest component
if True:
mesh_largest = trimesh.Trimesh(verts, faces)
@@ -366,36 +370,41 @@ def verts_on_largest_mesh(verts, faces):
v_large = v_large.astype(np.float32)
return v_large, f_large
def load_pointcloud(in_file):
plydata = PlyData.read(in_file)
vertices = np.stack([
plydata['vertex']['x'],
plydata['vertex']['y'],
plydata['vertex']['z']
], axis=1)
vertices = np.stack(
[
plydata["vertex"]["x"],
plydata["vertex"]["y"],
plydata["vertex"]["z"],
],
axis=1,
)
return vertices
# General config
def load_config(path, default_path=None):
''' Loads config file.
"""Loads config file.
Args:
Args:
path (str): path to config file
default_path (str): path to the default config file
'''
"""
# Load configuration from file itself
with open(path, 'r') as f:
with open(path) as f:
cfg_special = yaml.load(f, Loader=yaml.Loader)
# Check if we should inherit from a config
inherit_from = cfg_special.get('inherit_from')
inherit_from = cfg_special.get("inherit_from")
# If yes, load this config first as default
# If no, use the default_path
if inherit_from is not None:
cfg = load_config(inherit_from, default_path)
elif default_path is not None:
with open(default_path, 'r') as f:
with open(default_path) as f:
cfg = yaml.load(f, Loader=yaml.Loader)
else:
cfg = dict()
@@ -405,65 +414,67 @@ def load_config(path, default_path=None):
return cfg
def update_config(config, unknown):
# update config given args
for idx,arg in enumerate(unknown):
for idx, arg in enumerate(unknown):
if arg.startswith("--"):
keys = arg.replace("--","").split(':')
assert(len(keys)==2)
keys = arg.replace("--", "").split(":")
assert len(keys) == 2
k1, k2 = keys
argtype = type(config[k1][k2])
if argtype == bool:
v = unknown[idx+1].lower() == 'true'
v = unknown[idx + 1].lower() == "true"
else:
if config[k1][k2] is not None:
v = type(config[k1][k2])(unknown[idx+1])
v = type(config[k1][k2])(unknown[idx + 1])
else:
v = unknown[idx+1]
print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
v = unknown[idx + 1]
print(f"Changing {k1}:{k2} ---- {config[k1][k2]} to {v}")
config[k1][k2] = v
return config
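# Usage sketch (assumes train:batch_size exists in the loaded config): unrecognised CLI flags
# of the form --section:key override nested config entries in place.
import argparse
from src.utils import load_config, update_config
parser = argparse.ArgumentParser()
parser.add_argument("config", type=str)
args, unknown = parser.parse_known_args(["configs/default.yaml", "--train:batch_size", "8"])
cfg = load_config(args.config, "configs/default.yaml")
cfg = update_config(cfg, unknown)   # prints: Changing train:batch_size ---- <old value> to 8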
def initialize_logger(cfg):
out_dir = cfg['train']['out_dir']
out_dir = cfg["train"]["out_dir"]
if not os.path.exists(out_dir):
os.makedirs(out_dir)
cfg['train']['dir_model'] = os.path.join(out_dir, 'model')
os.makedirs(cfg['train']['dir_model'], exist_ok=True)
if cfg['train']['exp_mesh']:
cfg['train']['dir_mesh'] = os.path.join(out_dir, 'vis/mesh')
os.makedirs(cfg['train']['dir_mesh'], exist_ok=True)
if cfg['train']['exp_pcl']:
cfg['train']['dir_pcl'] = os.path.join(out_dir, 'vis/pointcloud')
os.makedirs(cfg['train']['dir_pcl'], exist_ok=True)
if cfg['train']['vis_rendering']:
cfg['train']['dir_rendering'] = os.path.join(out_dir, 'vis/rendering')
os.makedirs(cfg['train']['dir_rendering'], exist_ok=True)
if cfg['train']['o3d_show']:
cfg['train']['dir_o3d'] = os.path.join(out_dir, 'vis/o3d')
os.makedirs(cfg['train']['dir_o3d'], exist_ok=True)
cfg["train"]["dir_model"] = os.path.join(out_dir, "model")
os.makedirs(cfg["train"]["dir_model"], exist_ok=True)
if cfg["train"]["exp_mesh"]:
cfg["train"]["dir_mesh"] = os.path.join(out_dir, "vis/mesh")
os.makedirs(cfg["train"]["dir_mesh"], exist_ok=True)
if cfg["train"]["exp_pcl"]:
cfg["train"]["dir_pcl"] = os.path.join(out_dir, "vis/pointcloud")
os.makedirs(cfg["train"]["dir_pcl"], exist_ok=True)
if cfg["train"]["vis_rendering"]:
cfg["train"]["dir_rendering"] = os.path.join(out_dir, "vis/rendering")
os.makedirs(cfg["train"]["dir_rendering"], exist_ok=True)
if cfg["train"]["o3d_show"]:
cfg["train"]["dir_o3d"] = os.path.join(out_dir, "vis/o3d")
os.makedirs(cfg["train"]["dir_o3d"], exist_ok=True)
logger = logging.getLogger("train")
logger.setLevel(logging.DEBUG)
logger.handlers = []
# ch = logging.StreamHandler()
# logger.addHandler(ch)
fh = logging.FileHandler(os.path.join(cfg['train']['out_dir'], "log.txt"))
fh = logging.FileHandler(os.path.join(cfg["train"]["out_dir"], "log.txt"))
logger.addHandler(fh)
logger.info('Outout dir: %s', out_dir)
logger.info("Outout dir: %s", out_dir)
return logger
def update_recursive(dict1, dict2):
''' Update two config dictionaries recursively.
"""Update two config dictionaries recursively.
Args:
dict1 (dict): first dictionary to be updated
dict2 (dict): second dictionary which entries should be used
'''
"""
for k, v in dict2.items():
if k not in dict1:
dict1[k] = dict()
@@ -472,6 +483,7 @@ def update_recursive(dict1, dict2):
else:
dict1[k] = v
def export_pointcloud(name, points, normals=None):
if len(points.shape) > 2:
points = points[0]
@@ -487,6 +499,7 @@ def export_pointcloud(name, points, normals=None):
pcd.normals = o3d.utility.Vector3dVector(normals)
o3d.io.write_point_cloud(name, pcd)
def export_mesh(name, v, f):
if len(v.shape) > 2:
v, f = v[0], f[0]
@@ -498,59 +511,63 @@ def export_mesh(name, v, f):
mesh.triangles = o3d.utility.Vector3iVector(f)
o3d.io.write_triangle_mesh(name, mesh)
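# Usage sketch (output filenames are examples): write a two-triangle quad and its vertices
# to PLY via Open3D.
import numpy as np
from src.utils import export_mesh, export_pointcloud
v = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float64)
f = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32)
export_mesh("quad.ply", v, f)
export_pointcloud("quad_points.ply", v)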
def scale2onet(p, scale=1.2):
'''
Scale the point cloud from SAP to ONet range
'''
"""Scale the point cloud from SAP to ONet range."""
return (p - 0.5) * scale
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
if model is not None:
if schedule is not None:
optimizer = torch.optim.Adam([
{"params": model.parameters(),
"lr": schedule[0].get_learning_rate(epoch)},
{"params": inputs,
"lr": schedule[1].get_learning_rate(epoch)}])
elif 'lr' in cfg['train']:
optimizer = torch.optim.Adam([
{"params": model.parameters(),
"lr": float(cfg['train']['lr'])},
{"params": inputs,
"lr": float(cfg['train']['lr_pcl'])}])
optimizer = torch.optim.Adam(
[
{"params": model.parameters(), "lr": schedule[0].get_learning_rate(epoch)},
{"params": inputs, "lr": schedule[1].get_learning_rate(epoch)},
],
)
elif "lr" in cfg["train"]:
optimizer = torch.optim.Adam(
[
{"params": model.parameters(), "lr": float(cfg["train"]["lr"])},
{"params": inputs, "lr": float(cfg["train"]["lr_pcl"])},
],
)
else:
raise Exception('no known learning rate')
msg = "no known learning rate"
raise Exception(msg)
else:
if schedule is not None:
optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
else:
optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
optimizer = torch.optim.Adam([inputs], lr=float(cfg["train"]["lr_pcl"]))
return optimizer
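# Usage sketch (learning rates are made up): one Adam parameter group for the network and one
# for the point cloud being optimized.
import torch
from src.utils import update_optimizer
cfg = {"train": {"lr": 5e-4, "lr_pcl": 2e-2}}
inputs = torch.nn.Parameter(torch.rand(1, 2048, 3))   # optimizable point cloud
model = torch.nn.Linear(3, 3)                         # stand-in for Encode2Points
optimizer = update_optimizer(inputs, cfg, epoch=0, model=model)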
def is_url(url):
scheme = urllib.parse.urlparse(url).scheme
return scheme in ('http', 'https')
return scheme in ("http", "https")
def load_url(url):
'''Load a module dictionary from url.
"""Load a module dictionary from url.
Args:
url (str): url to saved model
'''
"""
print(url)
print('=> Loading checkpoint from url...')
print("=> Loading checkpoint from url...")
state_dict = model_zoo.load_url(url, progress=True)
return state_dict
class GaussianSmoothing(nn.Module):
"""
Apply gaussian smoothing on a
"""Apply gaussian smoothing on a
1d, 2d or 3d tensor. Filtering is performed separately for each channel
in the input using a depthwise convolution.
Arguments:
channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
kernel_size (int, sequence): Size of the gaussian kernel.
@@ -558,8 +575,9 @@ class GaussianSmoothing(nn.Module):
dim (int, optional): The number of dimensions of the data.
Default value is 2 (spatial).
"""
def __init__(self, channels, kernel_size, sigma, dim=3):
super(GaussianSmoothing, self).__init__()
super().__init__()
if isinstance(kernel_size, numbers.Number):
kernel_size = [kernel_size] * dim
if isinstance(sigma, numbers.Number):
@@ -569,15 +587,11 @@ class GaussianSmoothing(nn.Module):
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
[torch.arange(size, dtype=torch.float32) for size in kernel_size],
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
torch.exp(-((mgrid - mean) / std) ** 2 / 2)
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
@@ -586,7 +600,7 @@ class GaussianSmoothing(nn.Module):
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.register_buffer("weight", kernel)
self.groups = channels
if dim == 1:
@@ -596,36 +610,44 @@ class GaussianSmoothing(nn.Module):
elif dim == 3:
self.conv = F.conv3d
else:
msg = f"Only 1, 2 and 3 dimensions are supported. Received {dim}."
raise RuntimeError(
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
msg,
)
def forward(self, input):
"""
Apply gaussian filter to input.
"""Apply gaussian filter to input.
Arguments:
input (torch.Tensor): Input to apply gaussian filter on.
Returns:
filtered (torch.Tensor): Filtered output.
"""
return self.conv(input, weight=self.weight, groups=self.groups)
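# Usage sketch (illustrative): channel-wise Gaussian blur of a 3D grid via a depthwise
# convolution; forward() uses no padding, so each spatial dim shrinks by kernel_size - 1.
import torch
from src.utils import GaussianSmoothing
smoother = GaussianSmoothing(channels=1, kernel_size=5, sigma=1.0, dim=3)
grid = torch.rand(1, 1, 32, 32, 32)
blurred = smoother(grid)   # (1, 1, 28, 28, 28)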
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
def get_learning_rate_schedules(schedule_specs):
schedules = []
for key in schedule_specs.keys():
schedules.append(StepLearningRateSchedule(
schedule_specs[key]['initial'],
schedules.append(
StepLearningRateSchedule(
schedule_specs[key]["initial"],
schedule_specs[key]["interval"],
schedule_specs[key]["factor"],
schedule_specs[key]["final"]))
schedule_specs[key]["factor"],
schedule_specs[key]["final"],
),
)
return schedules
class LearningRateSchedule:
def get_learning_rate(self, epoch):
pass
class StepLearningRateSchedule(LearningRateSchedule):
def __init__(self, initial, interval, factor, final=1e-6):
self.initial = float(initial)
@@ -640,6 +662,7 @@ class StepLearningRateSchedule(LearningRateSchedule):
else:
return self.final
def adjust_learning_rate(lr_schedules, optimizer, epoch):
for i, param_group in enumerate(optimizer.param_groups):
param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

View file

@@ -1,95 +1,93 @@
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
from skimage import measure
from src.utils import calc_inters_points, grid_interp
import torch
from scipy import ndimage
from tqdm import trange
from torchvision.utils import save_image
from pdb import set_trace as st
from tqdm import trange
from src.utils import calc_inters_points, grid_interp
def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None):
''' Visualization.
"""Visualization.
Args:
data (dict): data dictionary
depth (int): PSR depth
out_path (str): output path for the mesh
'''
out_path (str): output path for the mesh
"""
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(verts)
mesh.triangles = o3d.utility.Vector3iVector(faces)
mesh.paint_uniform_color(np.array([0.7,0.7,0.7]))
mesh.paint_uniform_color(np.array([0.7, 0.7, 0.7]))
if color_v is not None:
mesh.vertex_colors = o3d.utility.Vector3dVector(color_v)
if vis is not None:
dir_o3d = cfg['train']['dir_o3d']
wire = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
dir_o3d = cfg["train"]["dir_o3d"]
o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
p = points.squeeze(0).detach().cpu().numpy()
n = normals.squeeze(0).detach().cpu().numpy()
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(p)
pcd.normals = o3d.utility.Vector3dVector(n)
pcd.paint_uniform_color(np.array([0.7,0.7,1.0]))
pcd.paint_uniform_color(np.array([0.7, 0.7, 1.0]))
# pcd = pcd.uniform_down_sample(5)
vis.clear_geometries()
vis.add_geometry(mesh)
vis.update_geometry(mesh)
#! Thingi wheel - an example of how to change cameras in Open3D viewers
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
vis.get_view_control().set_front([0.0461, -0.7467, 0.6635])
vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638])
vis.get_view_control().set_up([0.0520, 0.6651, 0.7449])
vis.get_view_control().set_zoom(0.7)
vis.poll_events()
out_path = os.path.join(dir_o3d, '{}.jpg'.format(it))
out_path = os.path.join(dir_o3d, f"{it}.jpg")
vis.capture_screen_image(out_path)
vis.clear_geometries()
vis.add_geometry(pcd, reset_bounding_box=False)
vis.update_geometry(pcd)
vis.get_render_option().point_show_normal=True # visualize point normals
vis.get_render_option().point_show_normal = True # visualize point normals
vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ])
vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ])
vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ])
vis.get_view_control().set_front([0.0461, -0.7467, 0.6635])
vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638])
vis.get_view_control().set_up([0.0520, 0.6651, 0.7449])
vis.get_view_control().set_zoom(0.7)
vis.poll_events()
out_path = os.path.join(dir_o3d, '{}_pcd.jpg'.format(it))
out_path = os.path.join(dir_o3d, f"{it}_pcd.jpg")
vis.capture_screen_image(out_path)
def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'):
def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name="video.mp4"):
if pose is not None:
device = psr_grid.device
# get world coordinate of grid points [-1, 1]
res = psr_grid.shape[-1]
x = torch.linspace(-1, 1, steps=res)
co_x, co_y, co_z = torch.meshgrid(x, x, x)
co_grid = torch.stack(
[co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)],
dim=1).to(device).unsqueeze(0)
co_grid = torch.stack([co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)], dim=1).to(device).unsqueeze(0)
# visualize the projected occ_soft value
res = 128
psr_grid = psr_grid.reshape(-1)
out_mask = psr_grid>0
in_mask = psr_grid<0
out_mask = psr_grid > 0
in_mask = psr_grid < 0
pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze()
vis_mask = (pix[..., 0]>=0) & (pix[..., 0]<=res-1) & \
(pix[..., 1]>=0) & (pix[..., 1]<=res-1)
vis_mask = (pix[..., 0] >= 0) & (pix[..., 0] <= res - 1) & (pix[..., 1] >= 0) & (pix[..., 1] <= res - 1)
pix_out = pix[vis_mask & out_mask]
pix_in = pix[vis_mask & in_mask]
img = torch.ones([res,res]).to(device)
psr_grid = torch.sigmoid(- psr_grid * 5)
img = torch.ones([res, res]).to(device)
psr_grid = torch.sigmoid(-psr_grid * 5)
img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask]
img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask]
# save_image(img, 'tmp.png', normalize=True)
@@ -98,78 +96,78 @@ def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.
dir_psr_vis = out_dir
os.makedirs(dir_psr_vis, exist_ok=True)
psr_grid = psr_grid.squeeze().detach().cpu().numpy()
axis = ['x', 'y', 'z']
s = psr_grid.shape[0]
for idx in trange(s):
my_dpi = 100
plt.figure(figsize=(1000/my_dpi, 300/my_dpi), dpi=my_dpi)
plt.figure(figsize=(1000 / my_dpi, 300 / my_dpi), dpi=my_dpi)
plt.subplot(1, 3, 1)
plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode='nearest'), cmap='nipy_spectral')
plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode="nearest"), cmap="nipy_spectral")
plt.clim(-1, 1)
plt.colorbar()
plt.title('x')
plt.title("x")
plt.grid("off")
plt.axis("off")
plt.subplot(1, 3, 2)
plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode='nearest'), cmap='nipy_spectral')
plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode="nearest"), cmap="nipy_spectral")
plt.clim(-1, 1)
plt.colorbar()
plt.title('y')
plt.title("y")
plt.grid("off")
plt.axis("off")
plt.subplot(1, 3, 3)
plt.imshow(ndimage.rotate(psr_grid[:,:,idx], 90, mode='nearest'), cmap='nipy_spectral')
plt.imshow(ndimage.rotate(psr_grid[:, :, idx], 90, mode="nearest"), cmap="nipy_spectral")
plt.clim(-1, 1)
plt.colorbar()
plt.title('z')
plt.title("z")
plt.grid("off")
plt.axis("off")
plt.savefig(os.path.join(dir_psr_vis, '{}'.format(idx)), pad_inches = 0, dpi=100)
plt.savefig(os.path.join(dir_psr_vis, f"{idx}"), pad_inches=0, dpi=100)
plt.close()
os.system("rm {}/{}".format(dir_psr_vis, out_video_name))
os.system("ffmpeg -framerate 25 -start_number 0 -i {}/%d.png -pix_fmt yuv420p -crf 17 {}/{}".format(dir_psr_vis, dir_psr_vis, out_video_name))
os.system(f"rm {dir_psr_vis}/{out_video_name}")
os.system(
f"ffmpeg -framerate 25 -start_number 0 -i {dir_psr_vis}/%d.png -pix_fmt yuv420p -crf 17 {dir_psr_vis}/{out_video_name}",
)
return None
return None
def visualize_mesh_phong(v, f, n, pose, img_size, name, device='cpu'):
def visualize_mesh_phong(v, f, n, pose, img_size, name, device="cpu"):
#! Mesh rendering using Phong shading model
_, mask, f_p, w = calc_inters_points(v, f, pose, img_size)
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
n_inters = w[..., 0, None] * n_a.squeeze() + \
w[..., 1, None] * n_b.squeeze() + \
w[..., 2, None] * n_c.squeeze()
n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze()
n_inters = n_inters.detach().to(device)
light_source = -pose.R@pose.T.squeeze()
light_source = -pose.R @ pose.T.squeeze()
light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float()
diffuse_per = torch.Tensor([0.7,0.7,0.7]).float()
ambiant = torch.Tensor([0.3,0.3,0.3]).float()
diffuse_per = torch.Tensor([0.7, 0.7, 0.7]).float()
ambiant = torch.Tensor([0.3, 0.3, 0.3]).float()
diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device)
phong = torch.ones([img_size[0]*img_size[1], 3]).to(device)
phong = torch.ones([img_size[0] * img_size[1], 3]).to(device)
phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0)
pp = phong.reshape(img_size[0], img_size[1], -1)
save_image(pp.permute(2, 0, 1), name)
def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None):
p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
# normals for p_inters
n_inters = None
if n is not None:
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
n_inters = w[..., 0, None] * n_a.squeeze() + \
w[..., 1, None] * n_b.squeeze() + \
w[..., 2, None] * n_c.squeeze()
if ray is not None:
ray = ray.squeeze()[mask]
fea = None
if fea_grid is not None:
fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze()
p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt)
# normals for p_inters
n_inters = None
if n is not None:
n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]]
n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze()
if ray is not None:
ray = ray.squeeze()[mask]
# use MLP to regress color
color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze()
fea = None
if fea_grid is not None:
fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze()
return color_pred, mask
# use MLP to regress color
color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze()
return color_pred, mask

189
train.py
View file

@@ -4,182 +4,185 @@ abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
import argparse
import shutil
import time
import numpy as np
import torch
import torch.optim as optim
import open3d as o3d
import numpy as np; np.set_printoptions(precision=4)
import shutil, argparse, time
from torch.utils.tensorboard import SummaryWriter
from src import config
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
from src.training import Trainer
from src.data import collate_remove_none, worker_init_fn
from src.model import Encode2Points
from src.utils import load_config, initialize_logger, \
AverageMeter, load_model_manual
from src.training import Trainer
from src.utils import AverageMeter, initialize_logger, load_config, load_model_manual
np.set_printoptions(precision=4)
def main():
parser = argparse.ArgumentParser(description='MNIST toy experiment')
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser = argparse.ArgumentParser(description="MNIST toy experiment")
parser.add_argument("config", type=str, help="Path to config file.")
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
args = parser.parse_args()
cfg = load_config(args.config, 'configs/default.yaml')
cfg = load_config(args.config, "configs/default.yaml")
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
input_type = cfg['data']['input_type']
batch_size = cfg['train']['batch_size']
model_selection_metric = cfg['train']['model_selection_metric']
cfg["data"]["input_type"]
batch_size = cfg["train"]["batch_size"]
model_selection_metric = cfg["train"]["model_selection_metric"]
# PYTORCH VERSION > 1.0.0
assert(float(torch.__version__.split('.')[-3]) > 0)
assert float(torch.__version__.split(".")[-3]) > 0
# boiler-plate
if cfg['train']['timestamp']:
cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
if cfg["train"]["timestamp"]:
cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
logger = initialize_logger(cfg)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml'))
shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))
logger.info("using GPU: " + torch.cuda.get_device_name(0))
# TensorboardX writer
tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
if not os.path.exists(tblogdir):
os.makedirs(tblogdir, exist_ok=True)
writer = SummaryWriter(log_dir=tblogdir)
inputs = None
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg)
vis_dataset = config.get_dataset('vis', cfg)
train_dataset = config.get_dataset("train", cfg)
val_dataset = config.get_dataset("val", cfg)
vis_dataset = config.get_dataset("vis", cfg)
collate_fn = collate_remove_none
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'], shuffle=True,
collate_fn=collate_fn,
worker_init_fn=worker_init_fn)
train_dataset,
batch_size=batch_size,
num_workers=cfg["train"]["n_workers"],
shuffle=True,
collate_fn=collate_fn,
worker_init_fn=worker_init_fn,
)
val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
collate_fn=collate_remove_none,
worker_init_fn=worker_init_fn)
val_dataset,
batch_size=1,
num_workers=cfg["train"]["n_workers_val"],
shuffle=False,
collate_fn=collate_remove_none,
worker_init_fn=worker_init_fn,
)
vis_loader = torch.utils.data.DataLoader(
vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False,
collate_fn=collate_fn,
worker_init_fn=worker_init_fn)
vis_dataset,
batch_size=1,
num_workers=cfg["train"]["n_workers_val"],
shuffle=False,
collate_fn=collate_fn,
worker_init_fn=worker_init_fn,
)
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
else:
model = Encode2Points(cfg).to(device)
n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info('Number of parameters: %d'% n_parameter)
logger.info("Number of parameters: %d" % n_parameter)
# load model
try:
# load model
state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
load_model_manual(state_dict['state_dict'], model)
out = "Load model from iteration %d" % state_dict.get('it', 0)
state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
load_model_manual(state_dict["state_dict"], model)
out = "Load model from iteration %d" % state_dict.get("it", 0)
logger.info(out)
# load point cloud
except:
state_dict = dict()
metric_val_best = state_dict.get(
'loss_val_best', np.inf)
logger.info('Current best validation metric (%s): %.8f'
% (model_selection_metric, metric_val_best))
metric_val_best = state_dict.get("loss_val_best", np.inf)
LR = float(cfg['train']['lr'])
logger.info(f"Current best validation metric ({model_selection_metric}): {metric_val_best:.8f}")
LR = float(cfg["train"]["lr"])
optimizer = optim.Adam(model.parameters(), lr=LR)
start_epoch = state_dict.get('epoch', -1)
it = state_dict.get('it', -1)
start_epoch = state_dict.get("epoch", -1)
it = state_dict.get("it", -1)
trainer = Trainer(cfg, optimizer, device=device)
runtime = {}
runtime['all'] = AverageMeter()
# training loop
for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1):
runtime["all"] = AverageMeter()
# training loop
for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
for batch in train_loader:
it += 1
start = time.time()
loss, loss_each = trainer.train_step(inputs, batch, model)
# measure elapsed time
end = time.time()
runtime['all'].update(end - start)
runtime["all"].update(end - start)
if it % cfg['train']['print_every'] == 0:
log_text = ('[Epoch %02d] it=%d, loss=%.4f') %(epoch, it, loss)
writer.add_scalar('train/loss', loss, it)
if it % cfg["train"]["print_every"] == 0:
log_text = ("[Epoch %02d] it=%d, loss=%.4f") % (epoch, it, loss)
writer.add_scalar("train/loss", loss, it)
if loss_each is not None:
for k, l in loss_each.items():
if l.item() != 0.:
log_text += (' loss_%s=%.4f') % (k, l.item())
writer.add_scalar('train/%s' % k, l, it)
log_text += (' time=%.3f / %.2f') % (runtime['all'].val, runtime['all'].sum)
if l.item() != 0.0:
log_text += f" loss_{k}={l.item():.4f}"
writer.add_scalar("train/%s" % k, l, it)
log_text += (" time={:.3f} / {:.2f}").format(runtime["all"].val, runtime["all"].sum)
logger.info(log_text)
if (it>0)& (it % cfg['train']['visualize_every'] == 0):
if (it > 0) & (it % cfg["train"]["visualize_every"] == 0):
for i, batch_vis in enumerate(vis_loader):
trainer.save(model, batch_vis, it, i)
if i >= 4:
break
logger.info('Saved mesh and pointcloud')
logger.info("Saved mesh and pointcloud")
# run validation
if it > 0 and (it % cfg['train']['validate_every']) == 0:
if it > 0 and (it % cfg["train"]["validate_every"]) == 0:
eval_dict = trainer.evaluate(val_loader, model)
metric_val = eval_dict[model_selection_metric]
logger.info('Validation metric (%s): %.4f'
% (model_selection_metric, metric_val))
for k, v in eval_dict.items():
writer.add_scalar('val/%s' % k, v, it)
logger.info(f"Validation metric ({model_selection_metric}): {metric_val:.4f}")
if -(metric_val - metric_val_best) >= 0:
for k, v in eval_dict.items():
writer.add_scalar("val/%s" % k, v, it)
if -(metric_val - metric_val_best) >= 0:
metric_val_best = metric_val
logger.info('New best model (loss %.4f)' % metric_val_best)
state = {'epoch': epoch,
'it': it,
'loss_val_best': metric_val_best}
state['state_dict'] = model.state_dict()
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt'))
logger.info("New best model (loss %.4f)" % metric_val_best)
state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
state["state_dict"] = model.state_dict()
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model_best.pt"))
# save checkpoint
if (epoch > 0) & (it % cfg['train']['checkpoint_every'] == 0):
state = {'epoch': epoch,
'it': it,
'loss_val_best': metric_val_best}
pcl = None
state['state_dict'] = model.state_dict()
torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
if (epoch > 0) & (it % cfg["train"]["checkpoint_every"] == 0):
state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
state["state_dict"] = model.state_dict()
if (it % cfg['train']['backup_every'] == 0):
torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt'))
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))
if it % cfg["train"]["backup_every"] == 0:
torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % it + ".pt"))
logger.info("Backup model at iteration %d" % it)
logger.info("Save new model at iteration %d" % it)
done=time.time()
time.time()
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()
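# Typical invocation (the config path is an example; any experiment YAML under configs/ works,
# with configs/default.yaml used as the fallback):
#   python train.py configs/<experiment>.yaml --seed 1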