diff --git a/eval_meshes.py b/eval_meshes.py index 96a82e5..fb5a05c 100644 --- a/eval_meshes.py +++ b/eval_meshes.py @@ -1,155 +1,145 @@ +import argparse +import os + +import numpy as np +import pandas as pd import torch import trimesh -from torch.utils.data import Dataset, DataLoader -import numpy as np; np.set_printoptions(precision=4) -import shutil, argparse, time, os -import pandas as pd -from src.data import collate_remove_none, collate_stack_together, worker_init_fn -from src.training import Trainer -from src.model import Encode2Points -from src.data import PointCloudField, IndexField, Shapes3dDataset -from src.utils import load_config, load_pointcloud -from src.eval import MeshEvaluator from tqdm import tqdm -from pdb import set_trace as st + +from src.data import IndexField, PointCloudField, Shapes3dDataset +from src.eval import MeshEvaluator +from src.utils import load_config, load_pointcloud + +np.set_printoptions(precision=4) def main(): - parser = argparse.ArgumentParser(description='MNIST toy experiment') - parser.add_argument('config', type=str, help='Path to config file.') - parser.add_argument('--no_cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') - parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.') - + parser = argparse.ArgumentParser(description="MNIST toy experiment") + parser.add_argument("config", type=str, help="Path to config file.") + parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.") + args = parser.parse_args() - cfg = load_config(args.config, 'configs/default.yaml') + cfg = load_config(args.config, "configs/default.yaml") use_cuda = not args.no_cuda and torch.cuda.is_available() - device = torch.device("cuda" if use_cuda else "cpu") - data_type = cfg['data']['data_type'] + torch.device("cuda" if use_cuda else "cpu") + cfg["data"]["data_type"] # Shorthands - out_dir = cfg['train']['out_dir'] - generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir']) + out_dir = cfg["train"]["out_dir"] + generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"]) - if cfg['generation'].get('iter', 0)!=0: - generation_dir += '_%04d'%cfg['generation']['iter'] + if cfg["generation"].get("iter", 0) != 0: + generation_dir += "_%04d" % cfg["generation"]["iter"] elif args.iter is not None: - generation_dir += '_%04d'%args.iter - - print('Evaluate meshes under %s'%generation_dir) - - out_file = os.path.join(generation_dir, 'eval_meshes_full.pkl') - out_file_class = os.path.join(generation_dir, 'eval_meshes.csv') - - # PYTORCH VERSION > 1.0.0 - assert(float(torch.__version__.split('.')[-3]) > 0) + generation_dir += "_%04d" % args.iter - pointcloud_field = PointCloudField(cfg['data']['pointcloud_file']) + print("Evaluate meshes under %s" % generation_dir) + + out_file = os.path.join(generation_dir, "eval_meshes_full.pkl") + out_file_class = os.path.join(generation_dir, "eval_meshes.csv") + + # PYTORCH VERSION > 1.0.0 + assert float(torch.__version__.split(".")[-3]) > 0 + + pointcloud_field = PointCloudField(cfg["data"]["pointcloud_file"]) fields = { - 'pointcloud': pointcloud_field, - 'idx': IndexField(), + "pointcloud": 
pointcloud_field,
+        "idx": IndexField(),
     }
-    print('Test split: ', cfg['data']['test_split'])
+    print("Test split: ", cfg["data"]["test_split"])
-    dataset_folder = cfg['data']['path']
+    dataset_folder = cfg["data"]["path"]
     dataset = Shapes3dDataset(
-        dataset_folder, fields,
-        cfg['data']['test_split'],
-        categories=cfg['data']['class'], cfg=cfg)
-
+        dataset_folder, fields, cfg["data"]["test_split"], categories=cfg["data"]["class"], cfg=cfg,
+    )
+
     # Loader
-    test_loader = torch.utils.data.DataLoader(
-        dataset, batch_size=1, num_workers=0, shuffle=False)
-
+    test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)
+
     # Evaluator
     evaluator = MeshEvaluator(n_points=100000)
-    eval_dicts = []
-    print('Evaluating meshes...')
-    for it, data in enumerate(tqdm(test_loader)):
-
+    eval_dicts = []
+    print("Evaluating meshes...")
+    for _it, data in enumerate(tqdm(test_loader)):
         if data is None:
-            print('Invalid data.')
+            print("Invalid data.")
             continue
-        mesh_dir = os.path.join(generation_dir, 'meshes')
-        pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
+        mesh_dir = os.path.join(generation_dir, "meshes")
+        pointcloud_dir = os.path.join(generation_dir, "pointcloud")
-        # Get index etc.
-        idx = data['idx'].item()
+        idx = data["idx"].item()
         try:
             model_dict = dataset.get_model_dict(idx)
         except AttributeError:
-            model_dict = {'model': str(idx), 'category': 'n/a'}
+            model_dict = {"model": str(idx), "category": "n/a"}
-        modelname = model_dict['model']
-        category_id = model_dict['category']
+        modelname = model_dict["model"]
+        category_id = model_dict["category"]
         try:
-            category_name = dataset.metadata[category_id].get('name', 'n/a')
+            category_name = dataset.metadata[category_id].get("name", "n/a")
         except AttributeError:
-            category_name = 'n/a'
+            category_name = "n/a"
-        if category_id != 'n/a':
+        if category_id != "n/a":
             mesh_dir = os.path.join(mesh_dir, category_id)
             pointcloud_dir = os.path.join(pointcloud_dir, category_id)
         # Evaluate
-        pointcloud_tgt = data['pointcloud'].squeeze(0).numpy()
-        normals_tgt = data['pointcloud.normals'].squeeze(0).numpy()
+        pointcloud_tgt = data["pointcloud"].squeeze(0).numpy()
+        normals_tgt = data["pointcloud.normals"].squeeze(0).numpy()
-
         eval_dict = {
-            'idx': idx,
-            'class id': category_id,
-            'class name': category_name,
-            'modelname':modelname,
+            "idx": idx,
+            "class id": category_id,
+            "class name": category_name,
+            "modelname": modelname,
         }
         eval_dicts.append(eval_dict)
-
+
         # Evaluate mesh
-        if cfg['test']['eval_mesh']:
-            mesh_file = os.path.join(mesh_dir, '%s.off' % modelname)
+        if cfg["test"]["eval_mesh"]:
+            mesh_file = os.path.join(mesh_dir, "%s.off" % modelname)
             if os.path.exists(mesh_file):
                 mesh = trimesh.load(mesh_file, process=False)
-                eval_dict_mesh = evaluator.eval_mesh(
-                    mesh, pointcloud_tgt, normals_tgt)
+                eval_dict_mesh = evaluator.eval_mesh(mesh, pointcloud_tgt, normals_tgt)
                 for k, v in eval_dict_mesh.items():
-                    eval_dict[k + ' (mesh)'] = v
+                    eval_dict[k + " (mesh)"] = v
             else:
-                print('Warning: mesh does not exist: %s' % mesh_file)
+                print("Warning: mesh does not exist: %s" % mesh_file)
         # Evaluate point cloud
-        if cfg['test']['eval_pointcloud']:
-            pointcloud_file = os.path.join(
-                pointcloud_dir, '%s.ply' % modelname)
+        if cfg["test"]["eval_pointcloud"]:
+            pointcloud_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
             if os.path.exists(pointcloud_file):
                 pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
-                eval_dict_pcl = evaluator.eval_pointcloud(
-                    pointcloud, pointcloud_tgt)
+                eval_dict_pcl
= evaluator.eval_pointcloud(pointcloud, pointcloud_tgt) for k, v in eval_dict_pcl.items(): - eval_dict[k + ' (pcl)'] = v + eval_dict[k + " (pcl)"] = v else: - print('Warning: pointcloud does not exist: %s' - % pointcloud_file) - - + print("Warning: pointcloud does not exist: %s" % pointcloud_file) + # Create pandas dataframe and save eval_df = pd.DataFrame(eval_dicts) - eval_df.set_index(['idx'], inplace=True) + eval_df.set_index(["idx"], inplace=True) eval_df.to_pickle(out_file) # Create CSV file with main statistics - eval_df_class = eval_df.groupby(by=['class name']).mean() - eval_df_class.loc['mean'] = eval_df_class.mean() + eval_df_class = eval_df.groupby(by=["class name"]).mean() + eval_df_class.loc["mean"] = eval_df_class.mean() eval_df_class.to_csv(out_file_class) # Print results print(eval_df_class) -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/generate.py b/generate.py index 08dd28b..1b6ec1e 100644 --- a/generate.py +++ b/generate.py @@ -1,155 +1,164 @@ -import torch -from torch.utils.data import Dataset, DataLoader -import numpy as np; np.set_printoptions(precision=4) -import shutil, argparse, time, os -import pandas as pd +import argparse +import os +import shutil from collections import defaultdict -from src import config -from src.utils import mc_from_psr, export_mesh, export_pointcloud -from src.dpsr import DPSR -from src.training import Trainer -from src.model import Encode2Points -from src.utils import load_config, load_model_manual, scale2onet, is_url, load_url + +import numpy as np +import pandas as pd +import torch from tqdm import tqdm -from pdb import set_trace as st + +from src import config +from src.dpsr import DPSR +from src.model import Encode2Points +from src.utils import ( + export_mesh, + export_pointcloud, + is_url, + load_config, + load_model_manual, + load_url, + mc_from_psr, + scale2onet, +) + +np.set_printoptions(precision=4) def main(): - parser = argparse.ArgumentParser(description='MNIST toy experiment') - parser.add_argument('config', type=str, help='Path to config file.') - parser.add_argument('--no_cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') - parser.add_argument('--iter', type=int, metavar='S', help='the training iteration to be evaluated.') - + parser = argparse.ArgumentParser(description="MNIST toy experiment") + parser.add_argument("config", type=str, help="Path to config file.") + parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.") + args = parser.parse_args() - cfg = load_config(args.config, 'configs/default.yaml') + cfg = load_config(args.config, "configs/default.yaml") use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") - data_type = cfg['data']['data_type'] - input_type = cfg['data']['input_type'] - vis_n_outputs = cfg['generation']['vis_n_outputs'] + cfg["data"]["data_type"] + cfg["data"]["input_type"] + vis_n_outputs = cfg["generation"]["vis_n_outputs"] if vis_n_outputs is None: vis_n_outputs = -1 # Shorthands - out_dir = cfg['train']['out_dir'] + out_dir = cfg["train"]["out_dir"] if not out_dir: os.makedirs(out_dir) - generation_dir = 
os.path.join(out_dir, cfg['generation']['generation_dir']) - out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl') - out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl') + generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"]) + out_time_file = os.path.join(generation_dir, "time_generation_full.pkl") + out_time_file_class = os.path.join(generation_dir, "time_generation.pkl") # PYTORCH VERSION > 1.0.0 - assert(float(torch.__version__.split('.')[-3]) > 0) + assert float(torch.__version__.split(".")[-3]) > 0 - dataset = config.get_dataset('test', cfg, return_idx=True) - test_loader = torch.utils.data.DataLoader( - dataset, batch_size=1, num_workers=0, shuffle=False) + dataset = config.get_dataset("test", cfg, return_idx=True) + test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False) model = Encode2Points(cfg).to(device) - + # load model try: - if is_url(cfg['test']['model_file']): - state_dict = load_url(cfg['test']['model_file']) - elif cfg['generation'].get('iter', 0)!=0: - state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% cfg['generation']['iter'])) - generation_dir += '_%04d'%cfg['generation']['iter'] + if is_url(cfg["test"]["model_file"]): + state_dict = load_url(cfg["test"]["model_file"]) + elif cfg["generation"].get("iter", 0) != 0: + state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % cfg["generation"]["iter"])) + generation_dir += "_%04d" % cfg["generation"]["iter"] elif args.iter is not None: - state_dict = torch.load(os.path.join(out_dir, 'model', '%04d.pt'% args.iter)) + state_dict = torch.load(os.path.join(out_dir, "model", "%04d.pt" % args.iter)) else: - state_dict = torch.load(os.path.join(out_dir, 'model_best.pt')) + state_dict = torch.load(os.path.join(out_dir, "model_best.pt")) - load_model_manual(state_dict['state_dict'], model) + load_model_manual(state_dict["state_dict"], model) except: - print('Model loading error. Exiting.') + print("Model loading error. Exiting.") exit() - - + # Generator generator = config.get_generator(model, cfg, device=device) - + # Determine what to generate - generate_mesh = cfg['generation']['generate_mesh'] - generate_pointcloud = cfg['generation']['generate_pointcloud'] - + generate_mesh = cfg["generation"]["generate_mesh"] + generate_pointcloud = cfg["generation"]["generate_pointcloud"] + # Statistics time_dicts = [] # Generate model.eval() - dpsr = DPSR(res=(cfg['generation']['psr_resolution'], - cfg['generation']['psr_resolution'], - cfg['generation']['psr_resolution']), - sig= cfg['generation']['psr_sigma']).to(device) - - + dpsr = DPSR( + res=( + cfg["generation"]["psr_resolution"], + cfg["generation"]["psr_resolution"], + cfg["generation"]["psr_resolution"], + ), + sig=cfg["generation"]["psr_sigma"], + ).to(device) # Count how many models already created model_counter = defaultdict(int) - print('Generating...') - for it, data in enumerate(tqdm(test_loader)): - + print("Generating...") + for _it, data in enumerate(tqdm(test_loader)): # Output folders - mesh_dir = os.path.join(generation_dir, 'meshes') - in_dir = os.path.join(generation_dir, 'input') - pointcloud_dir = os.path.join(generation_dir, 'pointcloud') - generation_vis_dir = os.path.join(generation_dir, 'vis', ) + mesh_dir = os.path.join(generation_dir, "meshes") + in_dir = os.path.join(generation_dir, "input") + pointcloud_dir = os.path.join(generation_dir, "pointcloud") + generation_vis_dir = os.path.join(generation_dir, "vis") # Get index etc. 
- idx = data['idx'].item() - + idx = data["idx"].item() + try: model_dict = dataset.get_model_dict(idx) except AttributeError: - model_dict = {'model': str(idx), 'category': 'n/a'} + model_dict = {"model": str(idx), "category": "n/a"} - modelname = model_dict['model'] - category_id = model_dict['category'] + modelname = model_dict["model"] + category_id = model_dict["category"] try: - category_name = dataset.metadata[category_id].get('name', 'n/a') + category_name = dataset.metadata[category_id].get("name", "n/a") except AttributeError: - category_name = 'n/a' - - if category_id != 'n/a': + category_name = "n/a" + + if category_id != "n/a": mesh_dir = os.path.join(mesh_dir, str(category_id)) pointcloud_dir = os.path.join(pointcloud_dir, str(category_id)) in_dir = os.path.join(in_dir, str(category_id)) folder_name = str(category_id) - if category_name != 'n/a': - folder_name = str(folder_name) + '_' + category_name.split(',')[0] + if category_name != "n/a": + folder_name = str(folder_name) + "_" + category_name.split(",")[0] generation_vis_dir = os.path.join(generation_vis_dir, folder_name) # Create directories if necessary if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir): os.makedirs(generation_vis_dir) - + if generate_mesh and not os.path.exists(mesh_dir): os.makedirs(mesh_dir) - + if generate_pointcloud and not os.path.exists(pointcloud_dir): os.makedirs(pointcloud_dir) - + if not os.path.exists(in_dir): os.makedirs(in_dir) # Timing dict time_dict = { - 'idx': idx, - 'class id': category_id, - 'class name': category_name, - 'modelname':modelname, + "idx": idx, + "class id": category_id, + "class name": category_name, + "modelname": modelname, } time_dicts.append(time_dict) # Generate outputs out_file_dict = {} - + if generate_mesh: #! 
deploy the generator to a separate class out = generator.generate_mesh(data) @@ -158,60 +167,56 @@ def main(): time_dict.update(stats_dict) # Write output - mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname) + mesh_out_file = os.path.join(mesh_dir, "%s.off" % modelname) export_mesh(mesh_out_file, scale2onet(v), f) - out_file_dict['mesh'] = mesh_out_file - + out_file_dict["mesh"] = mesh_out_file + if generate_pointcloud: - pointcloud_out_file = os.path.join( - pointcloud_dir, '%s.ply' % modelname) + pointcloud_out_file = os.path.join(pointcloud_dir, "%s.ply" % modelname) export_pointcloud(pointcloud_out_file, scale2onet(points), normals) - out_file_dict['pointcloud'] = pointcloud_out_file - - if cfg['generation']['copy_input']: - inputs_path = os.path.join(in_dir, '%s.ply' % modelname) - p = data.get('inputs').to(device) + out_file_dict["pointcloud"] = pointcloud_out_file + + if cfg["generation"]["copy_input"]: + inputs_path = os.path.join(in_dir, "%s.ply" % modelname) + p = data.get("inputs").to(device) export_pointcloud(inputs_path, scale2onet(p)) - out_file_dict['in'] = inputs_path - + out_file_dict["in"] = inputs_path + # Copy to visualization directory for first vis_n_output samples c_it = model_counter[category_id] if c_it < vis_n_outputs: # Save output files - img_name = '%02d.off' % c_it + "%02d.off" % c_it for k, filepath in out_file_dict.items(): ext = os.path.splitext(filepath)[1] - out_file = os.path.join(generation_vis_dir, '%02d_%s%s' - % (c_it, k, ext)) + out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, k, ext)) shutil.copyfile(filepath, out_file) - - # Also generate oracle meshes - if cfg['generation']['exp_oracle']: - points_gt = data.get('gt_points').to(device) - normals_gt = data.get('gt_points.normals').to(device) - psr_gt = dpsr(points_gt, normals_gt) - v, f, _ = mc_from_psr(psr_gt, - zero_level=cfg['data']['zero_level']) - out_file = os.path.join(generation_vis_dir, '%02d_%s%s' - % (c_it, 'mesh_oracle', '.off')) - export_mesh(out_file, scale2onet(v), f) - - model_counter[category_id] += 1 + # Also generate oracle meshes + if cfg["generation"]["exp_oracle"]: + points_gt = data.get("gt_points").to(device) + normals_gt = data.get("gt_points.normals").to(device) + psr_gt = dpsr(points_gt, normals_gt) + v, f, _ = mc_from_psr(psr_gt, zero_level=cfg["data"]["zero_level"]) + out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, "mesh_oracle", ".off")) + export_mesh(out_file, scale2onet(v), f) + + model_counter[category_id] += 1 # Create pandas dataframe and save time_df = pd.DataFrame(time_dicts) - time_df.set_index(['idx'], inplace=True) + time_df.set_index(["idx"], inplace=True) time_df.to_pickle(out_time_file) # Create pickle files with main statistics - time_df_class = time_df.groupby(by=['class name']).mean() - time_df_class.loc['mean'] = time_df_class.mean() + time_df_class = time_df.groupby(by=["class name"]).mean() + time_df_class.loc["mean"] = time_df_class.mean() time_df_class.to_pickle(out_time_file_class) # Print results - print('Timings [s]:') + print("Timings [s]:") print(time_df_class) -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/optim.py b/optim.py index fa552b0..d7ae4e4 100644 --- a/optim.py +++ b/optim.py @@ -1,79 +1,86 @@ +import argparse +import glob +import os +import shutil +import time + +import numpy as np +import open3d as o3d import torch import trimesh -import shutil, argparse, time, os, glob - -import numpy as np; 
np.set_printoptions(precision=4) -import open3d as o3d +from plyfile import PlyData +from pytorch3d.io import load_objs_as_meshes +from pytorch3d.ops import sample_points_from_meshes +from pytorch3d.structures import Meshes +from skimage import measure from torch.utils.tensorboard import SummaryWriter -from torchvision.utils import save_image -from torchvision.io import write_video from src.optimization import Trainer -from src.utils import load_config, update_config, initialize_logger, \ - get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\ - update_optimizer, export_pointcloud -from skimage import measure -from plyfile import PlyData -from pytorch3d.ops import sample_points_from_meshes -from pytorch3d.io import load_objs_as_meshes -from pytorch3d.structures import Meshes +from src.utils import ( + AverageMeter, + adjust_learning_rate, + export_pointcloud, + get_learning_rate_schedules, + initialize_logger, + load_config, + update_config, + update_optimizer, +) + +np.set_printoptions(precision=4) def main(): - parser = argparse.ArgumentParser(description='MNIST toy experiment') - parser.add_argument('config', type=str, help='Path to config file.') - parser.add_argument('--no_cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1457, metavar='S', - help='random seed') - - args, unknown = parser.parse_known_args() - cfg = load_config(args.config, 'configs/default.yaml') + parser = argparse.ArgumentParser(description="MNIST toy experiment") + parser.add_argument("config", type=str, help="Path to config file.") + parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1457, metavar="S", help="random seed") + + args, unknown = parser.parse_known_args() + cfg = load_config(args.config, "configs/default.yaml") cfg = update_config(cfg, unknown) use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") - data_type = cfg['data']['data_type'] - data_class = cfg['data']['class'] + data_type = cfg["data"]["data_type"] + cfg["data"]["class"] - print(cfg['train']['out_dir']) + print(cfg["train"]["out_dir"]) # PYTORCH VERSION > 1.0.0 - assert(float(torch.__version__.split('.')[-3]) > 0) + assert float(torch.__version__.split(".")[-3]) > 0 # boiler-plate - if cfg['train']['timestamp']: - cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S") + if cfg["train"]["timestamp"]: + cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S") logger = initialize_logger(cfg) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) np.random.seed(args.seed) - shutil.copyfile(args.config, - os.path.join(cfg['train']['out_dir'], 'config.yaml')) + shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml")) # tensorboardX writer - tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log") + tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log") if not os.path.exists(tblogdir): os.makedirs(tblogdir) - writer = SummaryWriter(log_dir=tblogdir) + SummaryWriter(log_dir=tblogdir) # initialize o3d visualizer vis = None - if cfg['train']['o3d_show']: + if cfg["train"]["o3d_show"]: vis = o3d.visualization.Visualizer() - vis.create_window(width=cfg['train']['o3d_window_size'], - height=cfg['train']['o3d_window_size']) + vis.create_window(width=cfg["train"]["o3d_window_size"], height=cfg["train"]["o3d_window_size"]) # 
initialize dataset - if data_type == 'point': - if cfg['data']['object_id'] != -1: - data_paths = sorted(glob.glob(cfg['data']['data_path'])) - data_path = data_paths[cfg['data']['object_id']] - print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths))) + if data_type == "point": + if cfg["data"]["object_id"] != -1: + data_paths = sorted(glob.glob(cfg["data"]["data_path"])) + data_path = data_paths[cfg["data"]["object_id"]] + print("Loaded %d/%d object" % (cfg["data"]["object_id"] + 1, len(data_paths))) else: - data_path = cfg['data']['data_path'] - print('Data loaded') - ext = data_path.split('.')[-1] - if ext == 'obj': # have GT mesh + data_path = cfg["data"]["data_path"] + print("Data loaded") + ext = data_path.split(".")[-1] + if ext == "obj": # have GT mesh mesh = load_objs_as_meshes([data_path], device=device) # scale the mesh into unit cube verts = mesh.verts_packed() @@ -81,20 +88,15 @@ def main(): center = verts.mean(0) mesh.offset_verts_(-center.expand(N, 3)) scale = max((verts - center).abs().max(0)[0]) - mesh.scale_verts_((1.0 / float(scale))) + mesh.scale_verts_(1.0 / float(scale)) # important for our DPSR to have the range in [0, 1), not reaching 1 mesh.scale_verts_(0.9) - target_pts, target_normals = sample_points_from_meshes(mesh, - num_samples=200000, return_normals=True) - elif ext == 'ply': # only have the point cloud + target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True) + elif ext == "ply": # only have the point cloud plydata = PlyData.read(data_path) - vertices = np.stack([plydata['vertex']['x'], - plydata['vertex']['y'], - plydata['vertex']['z']], axis=1) - normals = np.stack([plydata['vertex']['nx'], - plydata['vertex']['ny'], - plydata['vertex']['nz']], axis=1) + vertices = np.stack([plydata["vertex"]["x"], plydata["vertex"]["y"], plydata["vertex"]["z"]], axis=1) + normals = np.stack([plydata["vertex"]["nx"], plydata["vertex"]["ny"], plydata["vertex"]["nz"]], axis=1) N = vertices.shape[0] center = vertices.mean(0) scale = np.max(np.max(np.abs(vertices - center), axis=0)) @@ -104,212 +106,212 @@ def main(): target_pts = torch.tensor(vertices, device=device)[None].float() target_normals = torch.tensor(normals, device=device)[None].float() - mesh = None # no GT mesh + mesh = None # no GT mesh if not torch.is_tensor(center): center = torch.from_numpy(center) if not torch.is_tensor(scale): scale = torch.from_numpy(np.array([scale])) - data = {'target_points': target_pts, - 'target_normals': target_normals, # normals are never used - 'gt_mesh': mesh} + data = { + "target_points": target_pts, + "target_normals": target_normals, # normals are never used + "gt_mesh": mesh, + } else: raise NotImplementedError # save the input point cloud - if 'target_points' in data.keys(): - outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply') - if 'target_normals' in data.keys(): - export_pointcloud(outdir_pcl, data['target_points'], data['target_normals']) + if "target_points" in data.keys(): + outdir_pcl = os.path.join(cfg["train"]["out_dir"], "target_pcl.ply") + if "target_normals" in data.keys(): + export_pointcloud(outdir_pcl, data["target_points"], data["target_normals"]) else: - export_pointcloud(outdir_pcl, data['target_points']) + export_pointcloud(outdir_pcl, data["target_points"]) # save oracle PSR mesh (mesh from our PSR using GT point+normals) - if data.get('gt_mesh') is not None: - gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0) - pts_gt, norms_gt = 
sample_points_from_meshes(data['gt_mesh'], - num_samples=500000, return_normals=True) + if data.get("gt_mesh") is not None: + gt_verts, gt_faces = data["gt_mesh"].get_mesh_verts_faces(0) + pts_gt, norms_gt = sample_points_from_meshes(data["gt_mesh"], num_samples=500000, return_normals=True) pts_gt = (pts_gt + 1) / 2 from src.dpsr import DPSR - dpsr_tmp = DPSR(res=(cfg['model']['grid_res'], - cfg['model']['grid_res'], - cfg['model']['grid_res']), - sig=cfg['model']['psr_sigma']).to(device) + + dpsr_tmp = DPSR( + res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]), + sig=cfg["model"]["psr_sigma"], + ).to(device) target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device) target = torch.tanh(target) - s = target.shape[-1] # size of psr_grid + s = target.shape[-1] # size of psr_grid psr_grid_numpy = target.squeeze().detach().cpu().numpy() verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy) - verts = verts / s * 2. - 1 # [-1, 1] + verts = verts / s * 2.0 - 1 # [-1, 1] mesh = o3d.geometry.TriangleMesh() mesh.vertices = o3d.utility.Vector3dVector(verts) mesh.triangles = o3d.utility.Vector3iVector(faces) - outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply') + outdir_mesh = os.path.join(cfg["train"]["out_dir"], "oracle_mesh.ply") o3d.io.write_triangle_mesh(outdir_mesh, mesh) # initialize the source point cloud given an input mesh - if 'input_mesh' in cfg['train'].keys() and \ - os.path.isfile(cfg['train']['input_mesh']): - if cfg['train']['input_mesh'].split('/')[-2] == 'mesh': - mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh']) + if "input_mesh" in cfg["train"].keys() and os.path.isfile(cfg["train"]["input_mesh"]): + if cfg["train"]["input_mesh"].split("/")[-2] == "mesh": + mesh_tmp = trimesh.load_mesh(cfg["train"]["input_mesh"]) verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device) faces = torch.from_numpy(mesh_tmp.faces[None]).to(device) mesh = Meshes(verts=verts, faces=faces) - points, normals = sample_points_from_meshes(mesh, - num_samples=cfg['data']['num_points'], return_normals=True) + points, normals = sample_points_from_meshes( + mesh, num_samples=cfg["data"]["num_points"], return_normals=True, + ) # mesh is saved in the original scale of the gt points -= center.float().to(device) points /= scale.float().to(device) points *= 0.9 # make sure the points are within the range of [0, 1) - points = points / 2. + 0.5 + points = points / 2.0 + 0.5 else: # directly initialize from a point cloud - pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh']) + pcd = o3d.io.read_point_cloud(cfg["train"]["input_mesh"]) points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device) normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device) points -= center.float().to(device) points /= scale.float().to(device) points *= 0.9 - points = points / 2. + 0.5 - else: #! initialize our source point cloud from a sphere - sphere_radius = cfg['model']['sphere_radius'] - sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, - count=[256,256]) - points, idx = sphere_mesh.sample(cfg['data']['num_points'], - return_index=True) - points += 0.5 # make sure the points are within the range of [0, 1) + points = points / 2.0 + 0.5 + else: #! 
initialize our source point cloud from a sphere + sphere_radius = cfg["model"]["sphere_radius"] + sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256, 256]) + points, idx = sphere_mesh.sample(cfg["data"]["num_points"], return_index=True) + points += 0.5 # make sure the points are within the range of [0, 1) normals = sphere_mesh.face_normals[idx] points = torch.from_numpy(points).unsqueeze(0).to(device) normals = torch.from_numpy(normals).unsqueeze(0).to(device) - - points = torch.log(points/(1-points)) # inverse sigmoid + points = torch.log(points / (1 - points)) # inverse sigmoid inputs = torch.cat([points, normals], axis=-1).float() inputs.requires_grad = True - model = None # no network - + model = None # no network + # initialize optimizer - cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl'] - print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial']) - if 'schedule' in cfg['train']: - lr_schedules = get_learning_rate_schedules(cfg['train']['schedule']) + cfg["train"]["schedule"]["pcl"]["initial"] = cfg["train"]["lr_pcl"] + print("Initial learning rate:", cfg["train"]["schedule"]["pcl"]["initial"]) + if "schedule" in cfg["train"]: + lr_schedules = get_learning_rate_schedules(cfg["train"]["schedule"]) else: lr_schedules = None - optimizer = update_optimizer(inputs, cfg, - epoch=0, model=model, schedule=lr_schedules) + optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules) try: # load model - state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt')) - if ('pcl' in state_dict.keys()) & (state_dict['pcl'] is not None): - inputs = state_dict['pcl'].to(device) + state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt")) + if ("pcl" in state_dict.keys()) & (state_dict["pcl"] is not None): + inputs = state_dict["pcl"].to(device) inputs.requires_grad = True - - optimizer = update_optimizer(inputs, cfg, - epoch=state_dict.get('epoch'), schedule=lr_schedules) - - out = "Load model from epoch %d" % state_dict.get('epoch', 0) + + optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get("epoch"), schedule=lr_schedules) + + out = "Load model from epoch %d" % state_dict.get("epoch", 0) print(out) logger.info(out) except: state_dict = dict() - start_epoch = state_dict.get('epoch', -1) + start_epoch = state_dict.get("epoch", -1) trainer = Trainer(cfg, optimizer, device=device) runtime = {} - runtime['all'] = AverageMeter() - + runtime["all"] = AverageMeter() + # training loop - for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1): - + for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1): # schedule the learning rate - if (epoch>0) & (lr_schedules is not None): - if (epoch % lr_schedules[0].interval == 0): + if (epoch > 0) & (lr_schedules is not None): + if epoch % lr_schedules[0].interval == 0: adjust_learning_rate(lr_schedules, optimizer, epoch) - if len(lr_schedules) >1: - print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch, - lr_schedules[0].get_learning_rate(epoch), - lr_schedules[1].get_learning_rate(epoch))) + if len(lr_schedules) > 1: + print( + "[epoch {}] net_lr: {}, pcl_lr: {}".format( + epoch, lr_schedules[0].get_learning_rate(epoch), lr_schedules[1].get_learning_rate(epoch), + ), + ) else: - print('[epoch {}] adjust pcl_lr to: {}'.format(epoch, - lr_schedules[0].get_learning_rate(epoch))) + print(f"[epoch {epoch}] adjust pcl_lr to: {lr_schedules[0].get_learning_rate(epoch)}") start = time.time() loss, loss_each = 
trainer.train_step(data, inputs, model, epoch) - runtime['all'].update(time.time() - start) + runtime["all"].update(time.time() - start) - if epoch % cfg['train']['print_every'] == 0: - log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss) + if epoch % cfg["train"]["print_every"] == 0: + log_text = ("[Epoch %02d] loss=%.5f") % (epoch, loss) if loss_each is not None: for k, l in loss_each.items(): - if l.item() != 0.: - log_text += (' loss_%s=%.5f') % (k, l.item()) - - log_text += (' time=%.3f / %.3f') % (runtime['all'].val, - runtime['all'].sum) + if l.item() != 0.0: + log_text += f" loss_{k}={l.item():.5f}" + + log_text += (" time={:.3f} / {:.3f}").format(runtime["all"].val, runtime["all"].sum) logger.info(log_text) print(log_text) # visualize point clouds and meshes - if (epoch % cfg['train']['visualize_every'] == 0) & (vis is not None): + if (epoch % cfg["train"]["visualize_every"] == 0) & (vis is not None): trainer.visualize(data, inputs, model, epoch, o3d_vis=vis) - + # save outputs - if epoch % cfg['train']['save_every'] == 0: - trainer.save_mesh_pointclouds(inputs, epoch, - center.cpu().numpy(), - scale.cpu().numpy()*(1/0.9)) + if epoch % cfg["train"]["save_every"] == 0: + trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy() * (1 / 0.9)) # save checkpoints - if (epoch > 0) & (epoch % cfg['train']['checkpoint_every'] == 0): - state = {'epoch': epoch} - pcl = None + if (epoch > 0) & (epoch % cfg["train"]["checkpoint_every"] == 0): + state = {"epoch": epoch} if isinstance(inputs, torch.Tensor): - state['pcl'] = inputs.detach().cpu() - - torch.save(state, os.path.join(cfg['train']['dir_model'], - '%04d' % epoch + '.pt')) + state["pcl"] = inputs.detach().cpu() + + torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % epoch + ".pt")) print("Save new model at epoch %d" % epoch) logger.info("Save new model at epoch %d" % epoch) - torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt')) - + torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt")) + # resample and gradually add new points to the source pcl - if (epoch > 0) & \ - (cfg['train']['resample_every']!=0) & \ - (epoch % cfg['train']['resample_every'] == 0) & \ - (epoch < cfg['train']['total_epochs']): - inputs = trainer.point_resampling(inputs) - optimizer = update_optimizer(inputs, cfg, - epoch=epoch, model=model, schedule=lr_schedules) - trainer = Trainer(cfg, optimizer, device=device) - + if ( + (epoch > 0) + & (cfg["train"]["resample_every"] != 0) + & (epoch % cfg["train"]["resample_every"] == 0) + & (epoch < cfg["train"]["total_epochs"]) + ): + inputs = trainer.point_resampling(inputs) + optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules) + trainer = Trainer(cfg, optimizer, device=device) + # visualize the Open3D outputs - if cfg['train']['o3d_show']: - out_video_dir = os.path.join(cfg['train']['out_dir'], - 'vis/o3d/video.mp4') + if cfg["train"]["o3d_show"]: + out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video.mp4") if os.path.isfile(out_video_dir): - os.system('rm {}'.format(out_video_dir)) - os.system('ffmpeg -framerate 30 \ + os.system(f"rm {out_video_dir}") + os.system( + "ffmpeg -framerate 30 \ -start_number 0 \ -i {}/vis/o3d/%04d.jpg \ -pix_fmt yuv420p \ - -crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir)) - out_video_dir = os.path.join(cfg['train']['out_dir'], - 'vis/o3d/video_pcd.mp4') + -crf 17 {}".format( + cfg["train"]["out_dir"], out_video_dir, + ), + ) + out_video_dir = 
os.path.join(cfg["train"]["out_dir"], "vis/o3d/video_pcd.mp4") if os.path.isfile(out_video_dir): - os.system('rm {}'.format(out_video_dir)) - os.system('ffmpeg -framerate 30 \ + os.system(f"rm {out_video_dir}") + os.system( + "ffmpeg -framerate 30 \ -start_number 0 \ -i {}/vis/o3d/%04d_pcd.jpg \ -pix_fmt yuv420p \ - -crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir)) - print('Video saved.') + -crf 17 {}".format( + cfg["train"]["out_dir"], out_video_dir, + ), + ) + print("Video saved.") -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/optim_hierarchy.py b/optim_hierarchy.py index 23347b9..a8b3a38 100644 --- a/optim_hierarchy.py +++ b/optim_hierarchy.py @@ -1,69 +1,80 @@ -import sys, os import argparse +import os + from src.utils import load_config -import subprocess -os.environ['MKL_THREADING_LAYER'] = 'GNU' + +os.environ["MKL_THREADING_LAYER"] = "GNU" + def main(): + parser = argparse.ArgumentParser(description="MNIST toy experiment") + parser.add_argument("config", type=str, help="Path to config file.") + parser.add_argument("--start_res", type=int, default=-1, help="Resolution to start with.") + parser.add_argument("--object_id", type=int, default=-1, help="Object index.") - parser = argparse.ArgumentParser(description='MNIST toy experiment') - parser.add_argument('config', type=str, help='Path to config file.') - parser.add_argument('--start_res', type=int, default=-1, help='Resolution to start with.') - parser.add_argument('--object_id', type=int, default=-1, help='Object index.') + args, unknown = parser.parse_known_args() + cfg = load_config(args.config, "configs/default.yaml") - args, unknown = parser.parse_known_args() - cfg = load_config(args.config, 'configs/default.yaml') - - resolutions=[32, 64, 128, 256] - iterations=[1000, 1000, 1000, 200] - lrs=[2e-3, 2e-3*0.7, 2e-3*(0.7**2), 2e-3*(0.7**3)] # reduce lr - for idx,(res, iteration, lr) in enumerate(zip(resolutions, iterations, lrs)): - - if rescfg['model']['grid_res']: + if res > cfg["model"]["grid_res"]: continue - psr_sigma= 2 if res<=128 else 3 - + psr_sigma = 2 if res <= 128 else 3 + if res > 128: - psr_sigma = 5 if 'thingi_noisy' in args.config else 3 + psr_sigma = 5 if "thingi_noisy" in args.config else 3 if args.object_id != -1: - out_dir = os.path.join(cfg['train']['out_dir'], 'object_%02d'%args.object_id, 'res_%d'%res) + out_dir = os.path.join(cfg["train"]["out_dir"], "object_%02d" % args.object_id, "res_%d" % res) else: - out_dir = os.path.join(cfg['train']['out_dir'], 'res_%d'%res) - + out_dir = os.path.join(cfg["train"]["out_dir"], "res_%d" % res) + # sample from mesh when resampling is enabled, otherwise reuse the pointcloud - init_shape='mesh' if cfg['train']['resample_every']>0 else 'pointcloud' - - + init_shape = "mesh" if cfg["train"]["resample_every"] > 0 else "pointcloud" + if args.object_id != -1: - input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'], - 'object_%02d'%args.object_id, 'res_%d' % (resolutions[idx-1]), - 'vis', init_shape, '%04d.ply' % (iterations[idx-1])) + input_mesh = ( + "None" + if idx == 0 + else os.path.join( + cfg["train"]["out_dir"], + "object_%02d" % args.object_id, + "res_%d" % (resolutions[idx - 1]), + "vis", + init_shape, + "%04d.ply" % (iterations[idx - 1]), + ) + ) else: - input_mesh='None' if idx==0 else os.path.join(cfg['train']['out_dir'], - 'res_%d' % (resolutions[idx-1]), - 'vis', init_shape, '%04d.ply' % (iterations[idx-1])) - - - cmd = 'export MKL_SERVICE_FORCE_INTEL=1 && ' - cmd += "python optim.py %s --model:grid_res 
%d --model:psr_sigma %d \ + input_mesh = ( + "None" + if idx == 0 + else os.path.join( + cfg["train"]["out_dir"], + "res_%d" % (resolutions[idx - 1]), + "vis", + init_shape, + "%04d.ply" % (iterations[idx - 1]), + ) + ) + + cmd = "export MKL_SERVICE_FORCE_INTEL=1 && " + cmd += ( + "python optim.py %s --model:grid_res %d --model:psr_sigma %d \ --train:input_mesh %s --train:total_epochs %d \ --train:out_dir %s --train:lr_pcl %f \ - --data:object_id %d" % ( - args.config, - res, - psr_sigma, - input_mesh, - iteration, - out_dir, - lr, - args.object_id) + --data:object_id %d" + % (args.config, res, psr_sigma, input_mesh, iteration, out_dir, lr, args.object_id) + ) print(cmd) os.system(cmd) -if __name__=="__main__": + +if __name__ == "__main__": main() diff --git a/scripts/process_shapenet.py b/scripts/process_shapenet.py index cde672b..98cb491 100644 --- a/scripts/process_shapenet.py +++ b/scripts/process_shapenet.py @@ -1,14 +1,16 @@ -import os -import torch -import time import multiprocessing +import os +import time + import numpy as np +import torch from tqdm import tqdm + from src.dpsr import DPSR -data_path = 'data/ShapeNet' # path for ShapeNet from ONet -base = 'data' # output base directory -dataset_name = 'shapenet_psr' +data_path = "data/ShapeNet" # path for ShapeNet from ONet +base = "data" # output base directory +dataset_name = "shapenet_psr" multiprocess = True njobs = 8 save_pointcloud = True @@ -20,52 +22,56 @@ padding = 1.2 dpsr = DPSR(res=(resolution, resolution, resolution), sig=0) -def process_one(obj): - obj_name = obj.split('/')[-1] - c = obj.split('/')[-2] +def process_one(obj): + obj_name = obj.split("/")[-1] + c = obj.split("/")[-2] # create new for the current object out_path_cur = os.path.join(base, dataset_name, c) out_path_cur_obj = os.path.join(out_path_cur, obj_name) os.makedirs(out_path_cur_obj, exist_ok=True) - gt_path = os.path.join(data_path, c, obj_name, 'pointcloud.npz') + gt_path = os.path.join(data_path, c, obj_name, "pointcloud.npz") data = np.load(gt_path) - points = data['points'] - normals = data['normals'] + points = data["points"] + normals = data["normals"] # normalize the point to [0, 1) points = points / padding + 0.5 # to scale back during inference, we should: #! 
p = (p - 0.5) * padding - - if save_pointcloud: - outdir = os.path.join(out_path_cur_obj, 'pointcloud.npz') - # np.savez(outdir, points=points, normals=normals) - np.savez(outdir, points=data['points'], normals=data['normals']) - # return - - if save_psr_field: - psr_gt = dpsr(torch.from_numpy(points.astype(np.float32))[None], - torch.from_numpy(normals.astype(np.float32))[None]).squeeze().cpu().numpy().astype(np.float16) - outdir = os.path.join(out_path_cur_obj, 'psr.npz') + if save_pointcloud: + outdir = os.path.join(out_path_cur_obj, "pointcloud.npz") + # np.savez(outdir, points=points, normals=normals) + np.savez(outdir, points=data["points"], normals=data["normals"]) + # return + + if save_psr_field: + psr_gt = ( + dpsr(torch.from_numpy(points.astype(np.float32))[None], torch.from_numpy(normals.astype(np.float32))[None]) + .squeeze() + .cpu() + .numpy() + .astype(np.float16) + ) + + outdir = os.path.join(out_path_cur_obj, "psr.npz") np.savez(outdir, psr=psr_gt) - + def main(c): + print("---------------------------------------") + print(f"Processing {c} {split}") + print("---------------------------------------") - print('---------------------------------------') - print('Processing {} {}'.format(c, split)) - print('---------------------------------------') + for split in ["train", "val", "test"]: + fname = os.path.join(data_path, c, split + ".lst") + with open(fname) as f: + obj_list = f.read().splitlines() - for split in ['train', 'val', 'test']: - fname = os.path.join(data_path, c, split+'.lst') - with open(fname, 'r') as f: - obj_list = f.read().splitlines() - - obj_list = [c+'/'+s for s in obj_list] + obj_list = [c + "/" + s for s in obj_list] if multiprocess: # multiprocessing.set_start_method('spawn', force=True) @@ -81,21 +87,30 @@ def main(c): else: for obj in tqdm(obj_list): process_one(obj) - - print('Done Processing {} {}!'.format(c, split)) - - + + print(f"Done Processing {c} {split}!") + + if __name__ == "__main__": + classes = [ + "02691156", + "02828884", + "02933112", + "02958343", + "03211117", + "03001627", + "03636649", + "03691459", + "04090263", + "04256520", + "04379243", + "04401088", + "04530566", + ] - classes = ['02691156', '02828884', '02933112', - '02958343', '03211117', '03001627', - '03636649', '03691459', '04090263', - '04256520', '04379243', '04401088', '04530566'] - - t_start = time.time() for c in classes: main(c) - + t_end = time.time() - print('Total processing time: ', t_end - t_start) + print("Total processing time: ", t_end - t_start) diff --git a/src/config.py b/src/config.py index 2c318e3..a6999bd 100644 --- a/src/config.py +++ b/src/config.py @@ -1,146 +1,151 @@ -import yaml from torchvision import transforms + from src import data, generation from src.dpsr import DPSR -from ipdb import set_trace as st # Generator for final mesh extraction def get_generator(model, cfg, device, **kwargs): - ''' Returns the generator object. + """Returns the generator object. 
Args: model (nn.Module): Occupancy Network model cfg (dict): imported yaml config device (device): pytorch device - ''' - - if cfg['generation']['psr_resolution'] == 0: - psr_res = cfg['model']['grid_res'] - psr_sigma = cfg['model']['psr_sigma'] + """ + if cfg["generation"]["psr_resolution"] == 0: + psr_res = cfg["model"]["grid_res"] + psr_sigma = cfg["model"]["psr_sigma"] else: - psr_res = cfg['generation']['psr_resolution'] - psr_sigma = cfg['generation']['psr_sigma'] - - dpsr = DPSR(res=(psr_res, psr_res, psr_res), - sig= psr_sigma).to(device) + psr_res = cfg["generation"]["psr_resolution"] + psr_sigma = cfg["generation"]["psr_sigma"] + + dpsr = DPSR(res=(psr_res, psr_res, psr_res), sig=psr_sigma).to(device) - generator = generation.Generator3D( model, device=device, - threshold=cfg['data']['zero_level'], - sample=cfg['generation']['use_sampling'], - input_type = cfg['data']['input_type'], - padding=cfg['data']['padding'], + threshold=cfg["data"]["zero_level"], + sample=cfg["generation"]["use_sampling"], + input_type=cfg["data"]["input_type"], + padding=cfg["data"]["padding"], dpsr=dpsr, - psr_tanh=cfg['model']['psr_tanh'] + psr_tanh=cfg["model"]["psr_tanh"], ) return generator + # Datasets def get_dataset(mode, cfg, return_idx=False): - ''' Returns the dataset. + """Returns the dataset. Args: model (nn.Module): the model which is used cfg (dict): config dictionary return_idx (bool): whether to include an ID field - ''' - dataset_type = cfg['data']['dataset'] - dataset_folder = cfg['data']['path'] - categories = cfg['data']['class'] + """ + dataset_type = cfg["data"]["dataset"] + dataset_folder = cfg["data"]["path"] + categories = cfg["data"]["class"] # Get split splits = { - 'train': cfg['data']['train_split'], - 'val': cfg['data']['val_split'], - 'test': cfg['data']['test_split'], - 'vis': cfg['data']['val_split'], + "train": cfg["data"]["train_split"], + "val": cfg["data"]["val_split"], + "test": cfg["data"]["test_split"], + "vis": cfg["data"]["val_split"], } split = splits[mode] # Create dataset - if dataset_type == 'Shapes3D': + if dataset_type == "Shapes3D": fields = get_data_fields(mode, cfg) # Input fields inputs_field = get_inputs_field(mode, cfg) if inputs_field is not None: - fields['inputs'] = inputs_field + fields["inputs"] = inputs_field if return_idx: - fields['idx'] = data.IndexField() + fields["idx"] = data.IndexField() dataset = data.Shapes3dDataset( - dataset_folder, fields, + dataset_folder, + fields, split=split, categories=categories, - cfg = cfg + cfg=cfg, ) else: - raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset']) - + raise ValueError('Invalid dataset "%s"' % cfg["data"]["dataset"]) + return dataset def get_inputs_field(mode, cfg): - ''' Returns the inputs fields. + """Returns the inputs fields. 
Args: mode (str): the mode which is used cfg (dict): config dictionary - ''' - input_type = cfg['data']['input_type'] + """ + input_type = cfg["data"]["input_type"] if input_type is None: inputs_field = None - elif input_type == 'pointcloud': - noise_level = cfg['data']['pointcloud_noise'] - if cfg['data']['pointcloud_outlier_ratio']>0: - transform = transforms.Compose([ - data.SubsamplePointcloud(cfg['data']['pointcloud_n']), - data.PointcloudNoise(noise_level), - data.PointcloudOutliers(cfg['data']['pointcloud_outlier_ratio']) - ]) + elif input_type == "pointcloud": + noise_level = cfg["data"]["pointcloud_noise"] + if cfg["data"]["pointcloud_outlier_ratio"] > 0: + transform = transforms.Compose( + [ + data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]), + data.PointcloudNoise(noise_level), + data.PointcloudOutliers(cfg["data"]["pointcloud_outlier_ratio"]), + ], + ) else: - transform = transforms.Compose([ - data.SubsamplePointcloud(cfg['data']['pointcloud_n']), - data.PointcloudNoise(noise_level) - ]) + transform = transforms.Compose( + [ + data.SubsamplePointcloud(cfg["data"]["pointcloud_n"]), + data.PointcloudNoise(noise_level), + ], + ) - data_type = cfg['data']['data_type'] + data_type = cfg["data"]["data_type"] inputs_field = data.PointCloudField( - cfg['data']['pointcloud_file'], data_type, transform, - multi_files= cfg['data']['multi_files'] - ) + cfg["data"]["pointcloud_file"], + data_type, + transform, + multi_files=cfg["data"]["multi_files"], + ) else: - raise ValueError( - 'Invalid input type (%s)' % input_type) + raise ValueError("Invalid input type (%s)" % input_type) return inputs_field + def get_data_fields(mode, cfg): - ''' Returns the data fields. + """Returns the data fields. Args: mode (str): the mode which is used cfg (dict): imported yaml config - ''' - data_type = cfg['data']['data_type'] + """ + data_type = cfg["data"]["data_type"] fields = {} - - if (mode in ('val', 'test')): + + if mode in ("val", "test"): transform = data.SubsamplePointcloud(100000) else: - transform = data.SubsamplePointcloud(cfg['data']['num_gt_points']) - - data_name = cfg['data']['pointcloud_file'] - fields['gt_points'] = data.PointCloudField(data_name, - transform=transform, data_type=data_type, multi_files=cfg['data']['multi_files']) - if data_type == 'psr_full': - if mode != 'test': - fields['gt_psr'] = data.FullPSRField(multi_files=cfg['data']['multi_files']) - else: - raise ValueError('Invalid data type (%s)' % data_type) + transform = data.SubsamplePointcloud(cfg["data"]["num_gt_points"]) - return fields \ No newline at end of file + data_name = cfg["data"]["pointcloud_file"] + fields["gt_points"] = data.PointCloudField( + data_name, transform=transform, data_type=data_type, multi_files=cfg["data"]["multi_files"], + ) + if data_type == "psr_full": + if mode != "test": + fields["gt_psr"] = data.FullPSRField(multi_files=cfg["data"]["multi_files"]) + else: + raise ValueError("Invalid data type (%s)" % data_type) + + return fields diff --git a/src/data/__init__.py b/src/data/__init__.py index a7f46c1..3ee8142 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -1,14 +1,12 @@ - from src.data.core import ( - Shapes3dDataset, collate_remove_none, worker_init_fn, collate_stack_together -) -from src.data.fields import ( - IndexField, PointCloudField, FullPSRField -) -from src.data.transforms import ( - PointcloudNoise, SubsamplePointcloud, - PointcloudOutliers, + Shapes3dDataset, + collate_remove_none, + collate_stack_together, + worker_init_fn, ) +from src.data.fields import 
FullPSRField, IndexField, PointCloudField +from src.data.transforms import PointcloudNoise, PointcloudOutliers, SubsamplePointcloud + __all__ = [ # Core Shapes3dDataset, diff --git a/src/data/core.py b/src/data/core.py index 142747c..3541b16 100644 --- a/src/data/core.py +++ b/src/data/core.py @@ -1,45 +1,41 @@ -import os import logging -from torch.utils import data -from pdb import set_trace as st +import os + import numpy as np import yaml - +from torch.utils import data logger = logging.getLogger(__name__) # Fields -class Field(object): - ''' Data fields class. - ''' +class Field: + """Data fields class.""" def load(self, data_path, idx, category): - ''' Loads a data point. + """Loads a data point. Args: data_path (str): path to data file idx (int): index of data point category (int): index of category - ''' + """ raise NotImplementedError def check_complete(self, files): - ''' Checks if set is complete. + """Checks if set is complete. Args: files: files - ''' + """ raise NotImplementedError class Shapes3dDataset(data.Dataset): - ''' 3D Shapes dataset class. - ''' + """3D Shapes dataset class.""" - def __init__(self, dataset_folder, fields, split=None, - categories=None, no_except=True, transform=None, cfg=None): - ''' Initialization of the the 3D shape dataset. + def __init__(self, dataset_folder, fields, split=None, categories=None, no_except=True, transform=None, cfg=None): + """Initialization of the the 3D shape dataset. Args: dataset_folder (str): dataset folder @@ -49,7 +45,7 @@ class Shapes3dDataset(data.Dataset): no_except (bool): no exception transform (callable): transformation applied to data points cfg (yaml): config file - ''' + """ # Attributes self.dataset_folder = dataset_folder self.fields = fields @@ -60,76 +56,69 @@ class Shapes3dDataset(data.Dataset): # If categories is None, use all subfolders if categories is None: categories = os.listdir(dataset_folder) - categories = [c for c in categories - if os.path.isdir(os.path.join(dataset_folder, c))] + categories = [c for c in categories if os.path.isdir(os.path.join(dataset_folder, c))] # Read metadata file - metadata_file = os.path.join(dataset_folder, 'metadata.yaml') + metadata_file = os.path.join(dataset_folder, "metadata.yaml") if os.path.exists(metadata_file): - with open(metadata_file, 'r') as f: + with open(metadata_file) as f: self.metadata = yaml.load(f, Loader=yaml.Loader) else: - self.metadata = { - c: {'id': c, 'name': 'n/a'} for c in categories - } - + self.metadata = {c: {"id": c, "name": "n/a"} for c in categories} + # Set index for c_idx, c in enumerate(categories): - self.metadata[c]['idx'] = c_idx + self.metadata[c]["idx"] = c_idx # Get all models self.models = [] for c_idx, c in enumerate(categories): subpath = os.path.join(dataset_folder, c) if not os.path.isdir(subpath): - logger.warning('Category %s does not exist in dataset.' % c) + logger.warning("Category %s does not exist in dataset." 
% c) if split is None: self.models += [ - {'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ] + {"category": c, "model": m} + for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != "")] ] else: - split_file = os.path.join(subpath, split + '.lst') - with open(split_file, 'r') as f: - models_c = f.read().split('\n') - - if '' in models_c: - models_c.remove('') + split_file = os.path.join(subpath, split + ".lst") + with open(split_file) as f: + models_c = f.read().split("\n") + + if "" in models_c: + models_c.remove("") + + self.models += [{"category": c, "model": m} for m in models_c] - self.models += [ - {'category': c, 'model': m} - for m in models_c - ] - # precompute self.split = split - + def __len__(self): - ''' Returns the length of the dataset. - ''' + """Returns the length of the dataset.""" return len(self.models) def __getitem__(self, idx): - ''' Returns an item of the dataset. + """Returns an item of the dataset. Args: idx (int): ID of data point - ''' - - category = self.models[idx]['category'] - model = self.models[idx]['model'] - c_idx = self.metadata[category]['idx'] + """ + category = self.models[idx]["category"] + model = self.models[idx]["model"] + c_idx = self.metadata[category]["idx"] model_path = os.path.join(self.dataset_folder, category, model) data = {} info = c_idx - - if self.cfg['data']['multi_files'] is not None: - idx = np.random.randint(self.cfg['data']['multi_files']) - if self.split != 'train': + + if self.cfg["data"]["multi_files"] is not None: + idx = np.random.randint(self.cfg["data"]["multi_files"]) + if self.split != "train": idx = 0 for field_name, field in self.fields.items(): @@ -137,9 +126,8 @@ class Shapes3dDataset(data.Dataset): field_data = field.load(model_path, idx, info) except Exception: if self.no_except: - logger.warn( - 'Error occured when loading field %s of model %s' - % (field_name, model) + logger.warning( + f"Error occured when loading field {field_name} of model {model}", ) return None else: @@ -150,7 +138,7 @@ class Shapes3dDataset(data.Dataset): if k is None: data[field_name] = v else: - data['%s.%s' % (field_name, k)] = v + data[f"{field_name}.{k}"] = v else: data[field_name] = field_data @@ -158,78 +146,76 @@ class Shapes3dDataset(data.Dataset): data = self.transform(data) return data - - + def get_model_dict(self, idx): return self.models[idx] def test_model_complete(self, category, model): - ''' Tests if model is complete. + """Tests if model is complete. Args: model (str): modelname - ''' + """ model_path = os.path.join(self.dataset_folder, category, model) files = os.listdir(model_path) for field_name, field in self.fields.items(): if not field.check_complete(files): - logger.warn('Field "%s" is incomplete: %s' - % (field_name, model_path)) + logger.warning(f'Field "{field_name}" is incomplete: {model_path}') return False return True def collate_remove_none(batch): - ''' Collater that puts each data field into a tensor with outer dimension + """Collater that puts each data field into a tensor with outer dimension batch size. Args: batch: batch - ''' - + """ batch = list(filter(lambda x: x is not None, batch)) return data.dataloader.default_collate(batch) + def collate_stack_together(batch): - ''' Collater that puts each data field into a tensor with outer dimension + """Collater that puts each data field into a tensor with outer dimension batch size. 
Args: batch: batch - ''' - + """ batch = list(filter(lambda x: x is not None, batch)) keys = batch[0].keys() concat = {} - if len(batch)>1: + if len(batch) > 1: for key in keys: key_val = [item[key] for item in batch] concat[key] = np.concatenate(key_val, axis=0) - if key == 'inputs': + if key == "inputs": n_pts = [item[key].shape[0] for item in batch] - concat['batch_ind'] = np.concatenate( - [i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0) + concat["batch_ind"] = np.concatenate([i * np.ones(n, dtype=int) for i, n in enumerate(n_pts)], axis=0) return data.dataloader.default_collate([concat]) else: - n_pts = batch[0]['inputs'].shape[0] - batch[0]['batch_ind'] = np.zeros(n_pts, dtype=int) + n_pts = batch[0]["inputs"].shape[0] + batch[0]["batch_ind"] = np.zeros(n_pts, dtype=int) return data.dataloader.default_collate(batch) def worker_init_fn(worker_id): - ''' Worker init function to ensure true randomness. - ''' + """Worker init function to ensure true randomness.""" + def set_num_threads(nt): - try: - import mkl; mkl.set_num_threads(nt) - except: + try: + import mkl + + mkl.set_num_threads(nt) + except: pass torch.set_num_threads(1) - os.environ['IPC_ENABLE']='1' - for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']: + os.environ["IPC_ENABLE"] = "1" + for o in ["OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS", "MKL_NUM_THREADS"]: os.environ[o] = str(nt) random_data = os.urandom(4) diff --git a/src/data/fields.py b/src/data/fields.py index 7f78435..9716dd1 100644 --- a/src/data/fields.py +++ b/src/data/fields.py @@ -1,63 +1,61 @@ import os -import glob -import time -import random -from PIL import Image + import numpy as np -import trimesh + from src.data.core import Field -from pdb import set_trace as st class IndexField(Field): - ''' Basic index field.''' + """Basic index field.""" + def load(self, model_path, idx, category): - ''' Loads the index field. + """Loads the index field. Args: model_path (str): path to model idx (int): ID of data point category (int): index of category - ''' + """ return idx def check_complete(self, files): - ''' Check if field is complete. - + """Check if field is complete. + Args: files: files - ''' + """ return True + class FullPSRField(Field): def __init__(self, transform=None, multi_files=None): self.transform = transform # self.unpackbits = unpackbits self.multi_files = multi_files - - def load(self, model_path, idx, category): + def load(self, model_path, idx, category): # try: # t0 = time.time() if self.multi_files is not None: - psr_path = os.path.join(model_path, 'psr', 'psr_{:02d}.npz'.format(idx)) + psr_path = os.path.join(model_path, "psr", f"psr_{idx:02d}.npz") else: - psr_path = os.path.join(model_path, 'psr.npz') + psr_path = os.path.join(model_path, "psr.npz") psr_dict = np.load(psr_path) # t1 = time.time() - psr = psr_dict['psr'] + psr = psr_dict["psr"] psr = psr.astype(np.float32) # t2 = time.time() # print('load PSR: {:.4f}, change type: {:.4f}, total: {:.4f}'.format(t1 - t0, t2 - t1, t2-t0)) data = {None: psr} - + if self.transform is not None: data = self.transform(data) return data + class PointCloudField(Field): - ''' Point cloud field. + """Point cloud field. It provides the field used for point cloud data. These are the points randomly sampled on the mesh. 
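A note on the collate helpers defined above: `collate_remove_none` filters out samples whose fields failed to load and then falls back to the default collate, while `collate_stack_together` concatenates variable-size point clouds along the point dimension and records a per-point `batch_ind` (it assumes each sample carries an `inputs` array). A minimal, hypothetical usage sketch of the safer of the two — the config path, dataset folder, point-cloud file name, and batch/worker counts are illustrative and not taken from this patch:

    from torch.utils.data import DataLoader

    from src.data import IndexField, PointCloudField, Shapes3dDataset
    from src.data.core import collate_remove_none, worker_init_fn
    from src.utils import load_config

    # Illustrative config; in the repo this would be the experiment YAML.
    cfg = load_config("configs/example.yaml", "configs/default.yaml")

    fields = {
        "pointcloud": PointCloudField("pointcloud.npz"),  # file name is illustrative
        "idx": IndexField(),
    }
    dataset = Shapes3dDataset("data/ShapeNet", fields, split="train", cfg=cfg)

    loader = DataLoader(
        dataset,
        batch_size=4,
        num_workers=4,
        shuffle=True,
        collate_fn=collate_remove_none,  # drops samples whose fields failed to load
        worker_init_fn=worker_init_fn,   # caps per-worker threads, draws fresh OS randomness
    )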
@@ -66,53 +64,54 @@ class PointCloudField(Field): file_name (str): file name transform (list): list of transformations applied to data points multi_files (callable): number of files - ''' + """ + def __init__(self, file_name, data_type=None, transform=None, multi_files=None, padding=0.1, scale=1.2): self.file_name = file_name - self.data_type = data_type # to make sure the range of input is correct + self.data_type = data_type # to make sure the range of input is correct self.transform = transform self.multi_files = multi_files self.padding = padding self.scale = scale def load(self, model_path, idx, category): - ''' Loads the data point. + """Loads the data point. Args: model_path (str): path to model idx (int): ID of data point category (int): index of category - ''' + """ if self.multi_files is None: file_path = os.path.join(model_path, self.file_name) else: # num = np.random.randint(self.multi_files) # file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num)) - file_path = os.path.join(model_path, self.file_name, 'pointcloud_%02d.npz' % (idx)) + file_path = os.path.join(model_path, self.file_name, "pointcloud_%02d.npz" % (idx)) pointcloud_dict = np.load(file_path) - points = pointcloud_dict['points'].astype(np.float32) - normals = pointcloud_dict['normals'].astype(np.float32) - + points = pointcloud_dict["points"].astype(np.float32) + normals = pointcloud_dict["normals"].astype(np.float32) + data = { None: points, - 'normals': normals, + "normals": normals, } if self.transform is not None: data = self.transform(data) - - if self.data_type == 'psr_full': + + if self.data_type == "psr_full": # scale the point cloud to the range of (0, 1) data[None] = data[None] / self.scale + 0.5 return data def check_complete(self, files): - ''' Check if field is complete. - + """Check if field is complete. + Args: files: files - ''' - complete = (self.file_name in files) + """ + complete = self.file_name in files return complete diff --git a/src/data/transforms.py b/src/data/transforms.py index 8909594..2b874ab 100644 --- a/src/data/transforms.py +++ b/src/data/transforms.py @@ -2,24 +2,24 @@ import numpy as np # Transforms -class PointcloudNoise(object): - ''' Point cloud noise transformation class. +class PointcloudNoise: + """Point cloud noise transformation class. It adds noise to point cloud data. Args: stddev (int): standard deviation - ''' + """ def __init__(self, stddev): self.stddev = stddev def __call__(self, data): - ''' Calls the transformation. + """Calls the transformation. Args: data (dictionary): data dictionary - ''' + """ data_out = data.copy() points = data[None] noise = self.stddev * np.random.randn(*points.shape) @@ -27,60 +27,63 @@ class PointcloudNoise(object): data_out[None] = points + noise return data_out -class PointcloudOutliers(object): - ''' Point cloud outlier transformation class. + +class PointcloudOutliers: + """Point cloud outlier transformation class. It adds outliers to point cloud data. Args: ratio (int): outlier percentage to the entire point cloud - ''' + """ def __init__(self, ratio): self.ratio = ratio def __call__(self, data): - ''' Calls the transformation. + """Calls the transformation. 
Args: data (dictionary): data dictionary - ''' + """ data_out = data.copy() points = data[None] n_points = points.shape[0] - n_outlier_points = int(n_points*self.ratio) + n_outlier_points = int(n_points * self.ratio) ind = np.random.randint(0, n_points, n_outlier_points) - + outliers = np.random.uniform(-0.55, 0.55, (n_outlier_points, 3)) outliers = outliers.astype(np.float32) points[ind] = outliers data_out[None] = points return data_out -class SubsamplePointcloud(object): - ''' Point cloud subsampling transformation class. + +class SubsamplePointcloud: + """Point cloud subsampling transformation class. It subsamples the point cloud data. Args: N (int): number of points to be subsampled - ''' + """ + def __init__(self, N): self.N = N def __call__(self, data): - ''' Calls the transformation. + """Calls the transformation. Args: data (dict): data dictionary - ''' + """ data_out = data.copy() points = data[None] indices = np.random.randint(points.shape[0], size=self.N) data_out[None] = points[indices, :] - if 'normals' in data.keys(): - normals = data['normals'] - data_out['normals'] = normals[indices, :] + if "normals" in data.keys(): + normals = data["normals"] + data_out["normals"] = normals[indices, :] - return data_out \ No newline at end of file + return data_out diff --git a/src/data_loader.py b/src/data_loader.py index 75ea59c..d6822a0 100644 --- a/src/data_loader.py +++ b/src/data_loader.py @@ -1,17 +1,20 @@ import os -import cv2 -import torch -import numpy as np from glob import glob -from torch.utils import data -from src.utils import load_rgb, load_mask, get_camera_params + +import cv2 +import numpy as np +import torch from pytorch3d.renderer import PerspectiveCameras from skimage import img_as_float32 +from torch.utils import data + +from src.utils import get_camera_params, load_mask, load_rgb ################################################## # Below are for the differentiable renderer # Taken from https://github.com/lioryariv/idr/blob/main/code/utils/rend_util.py + def load_rgb(path): img = imageio.imread(path) img = img_as_float32(img) @@ -23,6 +26,7 @@ def load_rgb(path): # img = img.transpose(2, 0, 1) return img + def load_mask(path): alpha = imageio.imread(path, as_gray=True) alpha = img_as_float32(alpha) @@ -32,13 +36,13 @@ def load_mask(path): def get_camera_params(uv, pose, intrinsics): - if pose.shape[1] == 7: #In case of quaternion vector representation + if pose.shape[1] == 7: # In case of quaternion vector representation cam_loc = pose[:, 4:] - R = quat_to_rot(pose[:,:4]) - p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float() + R = quat_to_rot(pose[:, :4]) + p = torch.eye(4).repeat(pose.shape[0], 1, 1).cuda().float() p[:, :3, :3] = R p[:, :3, 3] = cam_loc - else: # In case of pose matrix representation + else: # In case of pose matrix representation cam_loc = pose[:, :3, 3] p = pose @@ -60,25 +64,27 @@ def get_camera_params(uv, pose, intrinsics): return ray_dirs, cam_loc + def quat_to_rot(q): batch_size, _ = q.shape q = F.normalize(q, dim=1) - R = torch.ones((batch_size, 3,3)).cuda() - qr=q[:,0] + R = torch.ones((batch_size, 3, 3)).cuda() + qr = q[:, 0] qi = q[:, 1] qj = q[:, 2] qk = q[:, 3] - R[:, 0, 0]=1-2 * (qj**2 + qk**2) - R[:, 0, 1] = 2 * (qj *qi -qk*qr) + R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2) + R[:, 0, 1] = 2 * (qj * qi - qk * qr) R[:, 0, 2] = 2 * (qi * qk + qr * qj) R[:, 1, 0] = 2 * (qj * qi + qk * qr) - R[:, 1, 1] = 1-2 * (qi**2 + qk**2) - R[:, 1, 2] = 2*(qj*qk - qi*qr) - R[:, 2, 0] = 2 * (qk * qi-qj * qr) - R[:, 2, 1] = 2 * (qj*qk + qi*qr) - R[:, 2, 2] 
= 1-2 * (qi**2 + qj**2) + R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2) + R[:, 1, 2] = 2 * (qj * qk - qi * qr) + R[:, 2, 0] = 2 * (qk * qi - qj * qr) + R[:, 2, 1] = 2 * (qj * qk + qi * qr) + R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2) return R + def lift(x, y, z, intrinsics): # parse intrinsics # intrinsics = intrinsics.cuda() @@ -88,7 +94,16 @@ def lift(x, y, z, intrinsics): cy = intrinsics[:, 1, 2] sk = intrinsics[:, 0, 1] - x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z + x_lift = ( + ( + x + - cx.unsqueeze(-1) + + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1) + - sk.unsqueeze(-1) * y / fy.unsqueeze(-1) + ) + / fx.unsqueeze(-1) + * z + ) y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z # homogeneous @@ -96,21 +111,18 @@ def lift(x, y, z, intrinsics): class PixelNeRFDTUDataset(data.Dataset): - """ - Processed DTU from pixelNeRF - """ - def __init__(self, - data_dir='data/DTU', - scan_id=65, - img_size=None, - device=None, - fixed_scale=0, - ): - data_dir = os.path.join(data_dir, "scan{}".format(scan_id)) - rgb_paths = [ - x for x in glob(os.path.join(data_dir, "image", "*")) - if (x.endswith(".jpg") or x.endswith(".png")) - ] + """Processed DTU from pixelNeRF.""" + + def __init__( + self, + data_dir="data/DTU", + scan_id=65, + img_size=None, + device=None, + fixed_scale=0, + ): + data_dir = os.path.join(data_dir, f"scan{scan_id}") + rgb_paths = [x for x in glob(os.path.join(data_dir, "image", "*")) if (x.endswith((".jpg", ".png")))] rgb_paths = sorted(rgb_paths) mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png"))) if len(mask_paths) == 0: @@ -129,27 +141,24 @@ class PixelNeRFDTUDataset(data.Dataset): all_T = [] for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)): - i = sel_indices[idx] rgb = load_rgb(rgb_path) mask = load_mask(mask_path) - rgb[~mask] = 0. + rgb[~mask] = 0.0 rgb = torch.from_numpy(rgb).float().to(device) mask = torch.from_numpy(mask).float().to(device) - x_scale = y_scale = 1.0 - xy_delta = 0.0 P = all_cam["world_mat_" + str(i)] P = P[:3] # scale the original shape to really [-0.9, 0.9] - if fixed_scale!=0.: + if fixed_scale != 0.0: scale_mat_new = np.eye(4, 4) - scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9] + scale_mat_new[:3, :3] *= fixed_scale # scale to [-0.9, 0.9] P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] @ scale_mat_new else: P = all_cam["world_mat_" + str(i)] @ all_cam["scale_mat_" + str(i)] - + P = P[:3, :4] K, R, t = cv2.decomposeProjectionMatrix(P)[:3] K = K / K[2, 2] @@ -158,38 +167,34 @@ class PixelNeRFDTUDataset(data.Dataset): ########!!!!! 
RR = torch.from_numpy(R).permute(1, 0).unsqueeze(0) - tt = torch.from_numpy(-R@(t[:3] / t[3])).permute(1, 0) + tt = torch.from_numpy(-R @ (t[:3] / t[3])).permute(1, 0) focal = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0) pc = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0) im_size = (rgb.shape[1], rgb.shape[0]) - + # check https://pytorch3d.org/docs/cameras for how to transform from screen to NDC s = min(im_size) - focal[:, 0] = focal[:, 0] * 2 / (s-1) - focal[:, 1] = focal[:, 1] * 2 /(s-1) - pc[:, 0] = -(pc[:, 0] - (im_size[0]-1)/2) * 2 / (s-1) - pc[:, 1] = -(pc[:, 1] - (im_size[1]-1)/2) * 2 / (s-1) + focal[:, 0] = focal[:, 0] * 2 / (s - 1) + focal[:, 1] = focal[:, 1] * 2 / (s - 1) + pc[:, 0] = -(pc[:, 0] - (im_size[0] - 1) / 2) * 2 / (s - 1) + pc[:, 1] = -(pc[:, 1] - (im_size[1] - 1) / 2) * 2 / (s - 1) - camera = PerspectiveCameras(focal_length=-focal, principal_point=pc, - device=device, R=RR, T=tt) + camera = PerspectiveCameras(focal_length=-focal, principal_point=pc, device=device, R=RR, T=tt) # calculate camera rays uv = uv_creation(im_size)[None].float() pose = np.eye(4, dtype=np.float32) pose[:3, :3] = R.transpose() - pose[:3,3] = (t[:3] / t[3])[:,0] + pose[:3, 3] = (t[:3] / t[3])[:, 0] pose = torch.from_numpy(pose)[None].float() intrinsics = np.eye(4) intrinsics[:3, :3] = K - intrinsics[0, 1] = 0. #! remove skew for now + intrinsics[0, 1] = 0.0 #! remove skew for now intrinsics = torch.from_numpy(intrinsics)[None].float() - rays, _ = get_camera_params(uv, pose, intrinsics) rays = -rays.to(device) - - all_poses.append(camera) all_imgs.append(rgb) all_masks.append(mask) @@ -198,7 +203,7 @@ class PixelNeRFDTUDataset(data.Dataset): # only for neural renderer all_K.append(torch.tensor(K).to(device)) all_R.append(torch.tensor(R).to(device)) - all_T.append(torch.tensor(t[:3]/t[3]).to(device)) + all_T.append(torch.tensor(t[:3] / t[3]).to(device)) all_imgs = torch.stack(all_imgs) all_masks = torch.stack(all_masks) @@ -210,19 +215,20 @@ class PixelNeRFDTUDataset(data.Dataset): all_T = torch.stack(all_T).permute(0, 2, 1).float() uv = uv_creation((all_imgs.size(2), all_imgs.size(1))) - self.data = {'rgbs': all_imgs, - 'masks': all_masks, - 'poses': all_poses, - 'rays': all_rays, - 'uv': uv, - 'light_pose': all_light_pose, # for rendering lights - 'K': all_K, - 'R': all_R, - 'T': all_T, - } - + self.data = { + "rgbs": all_imgs, + "masks": all_masks, + "poses": all_poses, + "rays": all_rays, + "uv": uv, + "light_pose": all_light_pose, # for rendering lights + "K": all_K, + "R": all_R, + "T": all_T, + } + def __len__(self): return 1 def __getitem__(self, idx): - return self.data \ No newline at end of file + return self.data diff --git a/src/dpsr.py b/src/dpsr.py index ff65287..b24892a 100644 --- a/src/dpsr.py +++ b/src/dpsr.py @@ -1,17 +1,17 @@ - -import torch -import torch.nn as nn -from src.utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize import numpy as np +import torch import torch.fft +import torch.nn as nn + +from src.utils import fftfreqs, grid_interp, img, point_rasterize, spec_gaussian_filter + class DPSR(nn.Module): def __init__(self, res, sig=10, scale=True, shift=True): - """ - :param res: tuple of output field resolution. eg., (128,128) + """:param res: tuple of output field resolution. 
eg., (128,128) :param sig: degree of gaussian smoothing """ - super(DPSR, self).__init__() + super().__init__() self.res = res self.sig = sig self.dim = len(res) @@ -22,45 +22,44 @@ class DPSR(nn.Module): self.scale = scale self.shift = shift self.register_buffer("G", G) - + def forward(self, V, N): - """ - :param V: (batch, nv, 2 or 3) tensor for point cloud coordinates + """:param V: (batch, nv, 2 or 3) tensor for point cloud coordinates :param N: (batch, nv, 2 or 3) tensor for point normals :return phi: (batch, res, res, ...) tensor of output indicator function field """ - assert(V.shape == N.shape) # [b, nv, ndims] + assert V.shape == N.shape # [b, nv, ndims] ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2] - - ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4)) - ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1])) - N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1] - omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1] + ras_s = torch.fft.rfftn(ras_p, dim=(2, 3, 4)) + ras_s = ras_s.permute(*tuple([0, *list(range(2, self.dim + 1)), self.dim + 1, 1])) + N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1] + + omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1] omega *= 2 * np.pi # normalize frequencies omega = omega.to(V.device) - + DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2) - - Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1] - Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2] - Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b] + + Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1] + Phi = DivN / (Lap + 1e-6) # [b, dim0, dim1, dim2/2+1, 2] + Phi = Phi.permute(*tuple([[*list(range(1, self.dim + 2)), 0]])) # [dim0, dim1, dim2/2+1, 2, b] Phi[tuple([0] * self.dim)] = 0 - Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2] - - phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3)) - + Phi = Phi.permute(*tuple([[self.dim + 1, *list(range(self.dim + 1))]])) # [b, dim0, dim1, dim2/2+1, 2] + + phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1, 2, 3)) + if self.shift or self.scale: # ensure values at points are zero - fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv] - if self.shift: # offset points to have mean of 0 - offset = torch.mean(fv, dim=-1) # [b,] + fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv] + if self.shift: # offset points to have mean of 0 + offset = torch.mean(fv, dim=-1) # [b,] phi -= offset.view(*tuple([-1] + [1] * self.dim)) - - phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]])) + + phi = phi.permute(*tuple([[*list(range(1, self.dim + 1)), 0]])) fv0 = phi[tuple([0] * self.dim)] # [b,] - phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))])) - + phi = phi.permute(*tuple([[self.dim, *list(range(self.dim))]])) + if self.scale: - phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5 - return phi \ No newline at end of file + phi = -phi / torch.abs(fv0.view(*tuple([-1] + [1] * self.dim))) * 0.5 + return phi diff --git a/src/eval.py b/src/eval.py index 5ce65fe..336a09d 100644 --- a/src/eval.py +++ b/src/eval.py @@ -1,43 +1,45 @@ import logging + import numpy as np -import trimesh from pykdtree.kdtree import KDTree EMPTY_PCL_DICT = { - 'completeness': 
np.sqrt(3), - 'accuracy': np.sqrt(3), - 'completeness2': 3, - 'accuracy2': 3, - 'chamfer': 6, + "completeness": np.sqrt(3), + "accuracy": np.sqrt(3), + "completeness2": 3, + "accuracy2": 3, + "chamfer": 6, } EMPTY_PCL_DICT_NORMALS = { - 'normals completeness': -1., - 'normals accuracy': -1., - 'normals': -1., + "normals completeness": -1.0, + "normals accuracy": -1.0, + "normals": -1.0, } logger = logging.getLogger(__name__) -class MeshEvaluator(object): - ''' Mesh evaluation class. +class MeshEvaluator: + """Mesh evaluation class. It handles the mesh evaluation process. + Args: - n_points (int): number of points to be used for evaluation - ''' + n_points (int): number of points to be used for evaluation. + """ def __init__(self, n_points=100000): self.n_points = n_points - def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1./1000, 1, 1000)): - ''' Evaluates a mesh. + def eval_mesh(self, mesh, pointcloud_tgt, normals_tgt, thresholds=np.linspace(1.0 / 1000, 1, 1000)): + """Evaluates a mesh. + Args: mesh (trimesh): mesh which should be evaluated pointcloud_tgt (numpy array): target point cloud normals_tgt (numpy array): target normals - thresholds (numpy arry): for F-Score - ''' + thresholds (numpy arry): for F-Score. + """ if len(mesh.vertices) != 0 and len(mesh.faces) != 0: pointcloud, idx = mesh.sample(self.n_points, return_index=True) @@ -47,25 +49,25 @@ class MeshEvaluator(object): pointcloud = np.empty((0, 3)) normals = np.empty((0, 3)) - out_dict = self.eval_pointcloud( - pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds) + out_dict = self.eval_pointcloud(pointcloud, pointcloud_tgt, normals, normals_tgt, thresholds=thresholds) return out_dict - def eval_pointcloud(self, pointcloud, pointcloud_tgt, - normals=None, normals_tgt=None, - thresholds=np.linspace(1./1000, 1, 1000)): - ''' Evaluates a point cloud. + def eval_pointcloud( + self, pointcloud, pointcloud_tgt, normals=None, normals_tgt=None, thresholds=np.linspace(1.0 / 1000, 1, 1000), + ): + """Evaluates a point cloud. + Args: pointcloud (numpy array): predicted point cloud pointcloud_tgt (numpy array): target point cloud normals (numpy array): predicted normals normals_tgt (numpy array): target normals - thresholds (numpy array): threshold values for the F-score calculation - ''' + thresholds (numpy array): threshold values for the F-score calculation. 
+ """ # Return maximum losses if pointcloud is empty if pointcloud.shape[0] == 0: - logger.warn('Empty pointcloud / mesh detected!') + logger.warning("Empty pointcloud / mesh detected!") out_dict = EMPTY_PCL_DICT.copy() if normals is not None and normals_tgt is not None: out_dict.update(EMPTY_PCL_DICT_NORMALS) @@ -74,11 +76,13 @@ class MeshEvaluator(object): pointcloud = np.asarray(pointcloud) pointcloud_tgt = np.asarray(pointcloud_tgt) - # Completeness: how far are the points of the target point cloud # from thre predicted point cloud completeness, completeness_normals = distance_p2p( - pointcloud_tgt, normals_tgt, pointcloud, normals + pointcloud_tgt, + normals_tgt, + pointcloud, + normals, ) recall = get_threshold_percentage(completeness, thresholds) completeness2 = completeness**2 @@ -90,7 +94,10 @@ class MeshEvaluator(object): # Accuracy: how far are th points of the predicted pointcloud # from the target pointcloud accuracy, accuracy_normals = distance_p2p( - pointcloud, normals, pointcloud_tgt, normals_tgt + pointcloud, + normals, + pointcloud_tgt, + normals_tgt, ) precision = get_threshold_percentage(accuracy, thresholds) accuracy2 = accuracy**2 @@ -101,68 +108,61 @@ class MeshEvaluator(object): # Chamfer distance chamferL2 = 0.5 * (completeness2 + accuracy2) - normals_correctness = ( - 0.5 * completeness_normals + 0.5 * accuracy_normals - ) + normals_correctness = 0.5 * completeness_normals + 0.5 * accuracy_normals chamferL1 = 0.5 * (completeness + accuracy) # F-Score - F = [ - 2 * precision[i] * recall[i] / (precision[i] + recall[i]) - for i in range(len(precision)) - ] + F = [2 * precision[i] * recall[i] / (precision[i] + recall[i]) for i in range(len(precision))] out_dict = { - 'completeness': completeness, - 'accuracy': accuracy, - 'normals completeness': completeness_normals, - 'normals accuracy': accuracy_normals, - 'normals': normals_correctness, - 'completeness2': completeness2, - 'accuracy2': accuracy2, - 'chamfer-L2': chamferL2, - 'chamfer-L1': chamferL1, - 'f-score': F[9], # threshold = 1.0% - 'f-score-15': F[14], # threshold = 1.5% - 'f-score-20': F[19], # threshold = 2.0% + "completeness": completeness, + "accuracy": accuracy, + "normals completeness": completeness_normals, + "normals accuracy": accuracy_normals, + "normals": normals_correctness, + "completeness2": completeness2, + "accuracy2": accuracy2, + "chamfer-L2": chamferL2, + "chamfer-L1": chamferL1, + "f-score": F[9], # threshold = 1.0% + "f-score-15": F[14], # threshold = 1.5% + "f-score-20": F[19], # threshold = 2.0% } return out_dict def distance_p2p(points_src, normals_src, points_tgt, normals_tgt): - ''' Computes minimal distances of each point in points_src to points_tgt. + """Computes minimal distances of each point in points_src to points_tgt. + Args: points_src (numpy array): source points normals_src (numpy array): source normals points_tgt (numpy array): target points - normals_tgt (numpy array): target normals - ''' + normals_tgt (numpy array): target normals. 
+ """ kdtree = KDTree(points_tgt) dist, idx = kdtree.query(points_src) if normals_src is not None and normals_tgt is not None: - normals_src = \ - normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) - normals_tgt = \ - normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True) + normals_src = normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) + normals_tgt = normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True) normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1) # Handle normals that point into wrong direction gracefully # (mostly due to mehtod not caring about this in generation) normals_dot_product = np.abs(normals_dot_product) else: - normals_dot_product = np.array( - [np.nan] * points_src.shape[0], dtype=np.float32) + normals_dot_product = np.array([np.nan] * points_src.shape[0], dtype=np.float32) return dist, normals_dot_product + def get_threshold_percentage(dist, thresholds): - ''' Evaluates a point cloud. + """Evaluates a point cloud. + Args: dist (numpy array): calculated distance - thresholds (numpy array): threshold values for the F-score calculation - ''' - in_threshold = [ - (dist <= t).mean() for t in thresholds - ] - return in_threshold \ No newline at end of file + thresholds (numpy array): threshold values for the F-score calculation. + """ + in_threshold = [(dist <= t).mean() for t in thresholds] + return in_threshold diff --git a/src/generation.py b/src/generation.py index 9abbe6d..33e0e3e 100644 --- a/src/generation.py +++ b/src/generation.py @@ -1,11 +1,12 @@ -import torch import time -import trimesh -import numpy as np + +import torch + from src.utils import mc_from_psr -class Generator3D(object): - ''' Generator class for Occupancy Networks. + +class Generator3D: + """Generator class for Occupancy Networks. It provides functions to generate the final mesh as well refining options. @@ -17,11 +18,20 @@ class Generator3D(object): padding (float): how much padding should be used for MISE sample (bool): whether z should be sampled input_type (str): type of input - ''' + """ - def __init__(self, model, points_batch_size=100000, - threshold=0.5, device=None, padding=0.1, - sample=False, input_type = None, dpsr=None, psr_tanh=True): + def __init__( + self, + model, + points_batch_size=100000, + threshold=0.5, + device=None, + padding=0.1, + sample=False, + input_type=None, + dpsr=None, + psr_tanh=True, + ): self.model = model.to(device) self.points_batch_size = points_batch_size self.threshold = threshold @@ -31,33 +41,32 @@ class Generator3D(object): self.sample = sample self.dpsr = dpsr self.psr_tanh = psr_tanh - + def generate_mesh(self, data, return_stats=True): - ''' Generates the output mesh. + """Generates the output mesh. 
Args: data (tensor): data tensor return_stats (bool): whether stats should be returned - ''' + """ self.model.eval() device = self.device stats_dict = {} - p = data.get('inputs', torch.empty(1, 0)).to(device) + p = data.get("inputs", torch.empty(1, 0)).to(device) t0 = time.time() points, normals = self.model(p) t1 = time.time() psr_grid = self.dpsr(points, normals) t2 = time.time() - v, f, _ = mc_from_psr(psr_grid, - zero_level=self.threshold) - stats_dict['pcl'] = t1 - t0 - stats_dict['dpsr'] = t2 - t1 - stats_dict['mc'] = time.time() - t2 - stats_dict['total'] = time.time() - t0 + v, f, _ = mc_from_psr(psr_grid, zero_level=self.threshold) + stats_dict["pcl"] = t1 - t0 + stats_dict["dpsr"] = t2 - t1 + stats_dict["mc"] = time.time() - t2 + stats_dict["total"] = time.time() - t0 if return_stats: return v, f, points, normals, stats_dict else: - return v, f, points, normals \ No newline at end of file + return v, f, points, normals diff --git a/src/model.py b/src/model.py index a2e978f..10c209d 100644 --- a/src/model.py +++ b/src/model.py @@ -1,18 +1,17 @@ -import torch -import numpy as np import time -from src.utils import point_rasterize, grid_interp, mc_from_psr, \ -calc_inters_points -from src.dpsr import DPSR + +import torch import torch.nn as nn -from src.network import encoder_dict, decoder_dict + +from src.network import decoder_dict, encoder_dict from src.network.utils import map2local +from src.utils import calc_inters_points, grid_interp, mc_from_psr, point_rasterize + class PSR2Mesh(torch.autograd.Function): @staticmethod def forward(ctx, psr_grid): - """ - In the forward pass we receive a Tensor containing the input and return + """In the forward pass we receive a Tensor containing the input and return a Tensor containing the output. ctx is a context object that can be used to stash information for backward computation. You can cache arbitrary objects for use in the backward pass using the ctx.save_for_backward method. @@ -29,8 +28,7 @@ class PSR2Mesh(torch.autograd.Function): @staticmethod def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals): - """ - In the backward pass we receive a Tensor containing the gradient of the loss + """In the backward pass we receive a Tensor containing the gradient of the loss with respect to the output, and we need to compute the gradient of the loss with respect to the input. """ @@ -39,17 +37,17 @@ class PSR2Mesh(torch.autograd.Function): # matrix multiplication between dL/dV and dV/dPSR # dV/dPSR = - normals grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0)) - grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res - + grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res + return grad_grid + class PSR2SurfacePoints(torch.autograd.Function): @staticmethod def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample): verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True) - verts = verts * 2. - 1. 
# within the range of [-1, 1] + verts = verts * 2.0 - 1.0 # within the range of [-1, 1] - p_all, n_all, mask_all = [], [], [] for i in range(len(poses)): @@ -67,7 +65,6 @@ class PSR2SurfacePoints(torch.autograd.Function): n_inters_all = torch.cat(n_all, dim=0) mask_visible = torch.stack(mask_all, dim=0) - res = torch.tensor(psr_grid.detach().shape[2]) ctx.save_for_backward(p_inters_all, n_inters_all, res) @@ -80,30 +77,31 @@ class PSR2SurfacePoints(torch.autograd.Function): # grad from the p_inters via MLP renderer grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None]) - grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res - + grad_grid_pts = point_rasterize((pts[None] + 1) / 2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res + return grad_grid_pts, None, None, None, None, None + class Encode2Points(nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg - - encoder = cfg['model']['encoder'] - decoder = cfg['model']['decoder'] - dim = cfg['data']['dim'] # input dim - c_dim = cfg['model']['c_dim'] - encoder_kwargs = cfg['model']['encoder_kwargs'] - if encoder_kwargs == None: + + encoder = cfg["model"]["encoder"] + decoder = cfg["model"]["decoder"] + dim = cfg["data"]["dim"] # input dim + c_dim = cfg["model"]["c_dim"] + encoder_kwargs = cfg["model"]["encoder_kwargs"] + if encoder_kwargs is None: encoder_kwargs = {} - decoder_kwargs = cfg['model']['decoder_kwargs'] - padding = cfg['data']['padding'] - self.predict_normal = cfg['model']['predict_normal'] - self.predict_offset = cfg['model']['predict_offset'] + decoder_kwargs = cfg["model"]["decoder_kwargs"] + cfg["data"]["padding"] + self.predict_normal = cfg["model"]["predict_normal"] + self.predict_offset = cfg["model"]["predict_offset"] out_dim = 3 out_dim_offset = 3 - num_offset = cfg['data']['num_offset'] + num_offset = cfg["data"]["num_offset"] # each point predict more than one offset to add output points if num_offset > 1: out_dim_offset = out_dim * num_offset @@ -111,44 +109,40 @@ class Encode2Points(nn.Module): # local mapping self.map2local = None - if cfg['model']['local_coord']: - if 'unet' in encoder_kwargs.keys(): - unit_size = 1 / encoder_kwargs['plane_resolution'] + if cfg["model"]["local_coord"]: + if "unet" in encoder_kwargs.keys(): + unit_size = 1 / encoder_kwargs["plane_resolution"] else: - unit_size = 1 / encoder_kwargs['grid_resolution'] - + unit_size = 1 / encoder_kwargs["grid_resolution"] + local_mapping = map2local(unit_size) self.encoder = encoder_dict[encoder]( - dim=dim, c_dim=c_dim, map2local=local_mapping, - **encoder_kwargs + dim=dim, + c_dim=c_dim, + map2local=local_mapping, + **encoder_kwargs, ) if self.predict_normal: # decoder for normal prediction - self.decoder_normal = decoder_dict[decoder]( - dim=dim, c_dim=c_dim, out_dim=out_dim, - **decoder_kwargs) + self.decoder_normal = decoder_dict[decoder](dim=dim, c_dim=c_dim, out_dim=out_dim, **decoder_kwargs) if self.predict_offset: # decoder for offset prediction self.decoder_offset = decoder_dict[decoder]( - dim=dim, c_dim=c_dim, out_dim=out_dim_offset, - map2local=local_mapping, - **decoder_kwargs) + dim=dim, c_dim=c_dim, out_dim=out_dim_offset, map2local=local_mapping, **decoder_kwargs, + ) + + self.s_off = cfg["model"]["s_offset"] - self.s_off = cfg['model']['s_offset'] - - def forward(self, p): - ''' Performs a forward pass through the network. + """Performs a forward pass through the network. 
Args: p (tensor): input unoriented points - ''' - + """ time_dict = {} - mask = None - + batch_size = p.size(0) points = p.clone() @@ -168,14 +162,12 @@ class Encode2Points(nn.Module): if self.predict_normal: normals = self.decoder_normal(points, c) t2 = time.perf_counter() - - time_dict['encode'] = t1 - t0 - time_dict['predict'] = t2 - t1 - - points = torch.clamp(points, 0.0, 0.99) - if self.cfg['model']['normal_normalize']: - normals = normals / (normals.norm(dim=-1, keepdim=True)+1e-8) - + time_dict["encode"] = t1 - t0 + time_dict["predict"] = t2 - t1 + + points = torch.clamp(points, 0.0, 0.99) + if self.cfg["model"]["normal_normalize"]: + normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-8) + return points, normals - \ No newline at end of file diff --git a/src/model_rgb.py b/src/model_rgb.py index 79f20a1..80ac7b2 100644 --- a/src/model_rgb.py +++ b/src/model_rgb.py @@ -1,17 +1,19 @@ import torch +from pytorch3d.renderer import ( + MeshRasterizer, + MeshRenderer, + PerspectiveCameras, + RasterizationSettings, + SoftSilhouetteShader, +) +from pytorch3d.structures import Meshes + from src.network.net_rgb import RenderingNetwork from src.utils import approx_psr_grad -from pytorch3d.renderer import ( - RasterizationSettings, - PerspectiveCameras, - MeshRenderer, - MeshRasterizer, - SoftSilhouetteShader) -from pytorch3d.structures import Meshes def approx_psr_grad(psr_grid, res, normalize=True): - delta_x = delta_y = delta_z = 1/res + delta_x = delta_y = delta_z = 1 / res psr_pad = torch.nn.ReplicationPad3d(1)(psr_grid).squeeze() grad_x = (psr_pad[2:, :, :] - psr_pad[:-2, :, :]) / 2 / delta_x @@ -35,76 +37,77 @@ class SAP2Image(nn.Module): self.psr2sur = PSR2SurfacePoints.apply self.psr2mesh = PSR2Mesh.apply # initialize DPSR - self.dpsr = DPSR(res=(cfg['model']['grid_res'], - cfg['model']['grid_res'], - cfg['model']['grid_res']), - sig=cfg['model']['psr_sigma']) + self.dpsr = DPSR( + res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]), + sig=cfg["model"]["psr_sigma"], + ) self.cfg = cfg - if cfg['train']['l_weight']['rgb'] != 0.: - self.rendering_network = RenderingNetwork(**cfg['model']['renderer']) + if cfg["train"]["l_weight"]["rgb"] != 0.0: + self.rendering_network = RenderingNetwork(**cfg["model"]["renderer"]) - if cfg['train']['l_weight']['mask'] != 0.: + if cfg["train"]["l_weight"]["mask"] != 0.0: # initialize rasterizer sigma = 1e-4 raster_settings_soft = RasterizationSettings( - image_size=img_size, - blur_radius=np.log(1. 
/ 1e-4 - 1.)*sigma, + image_size=img_size, + blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma, faces_per_pixel=150, - perspective_correct=False + perspective_correct=False, ) - # initialize silhouette renderer + # initialize silhouette renderer self.mesh_rasterizer = MeshRenderer( rasterizer=MeshRasterizer( - raster_settings=raster_settings_soft + raster_settings=raster_settings_soft, ), - shader=SoftSilhouetteShader() + shader=SoftSilhouetteShader(), ) self.cfg = cfg self.img_size = img_size - + def forward(self, inputs, data): - points, normals = inputs[...,:3], inputs[...,3:] + points, normals = inputs[..., :3], inputs[..., 3:] points = torch.sigmoid(points) normals = normals / normals.norm(dim=-1, keepdim=True) # DPSR to get grid psr_grid = self.dpsr(points, normals).unsqueeze(1) psr_grid = torch.tanh(psr_grid) - + return self.render_img(psr_grid, data) def render_img(self, psr_grid, data): - - n_views = len(data['masks']) - n_views_per_iter = self.cfg['data']['n_views_per_iter'] - - rgb_render_mode = self.cfg['model']['renderer']['mode'] - uv = data['uv'] + n_views = len(data["masks"]) + n_views_per_iter = self.cfg["data"]["n_views_per_iter"] + + self.cfg["model"]["renderer"]["mode"] + uv = data["uv"] idx = np.random.randint(0, n_views, n_views_per_iter) - pose = [data['poses'][i] for i in idx] - rgb = data['rgbs'][idx] - mask_gt = data['masks'][idx] + pose = [data["poses"][i] for i in idx] + rgb = data["rgbs"][idx] + mask_gt = data["masks"][idx] ray = None pred_rgb = None pred_mask = None - - if self.cfg['train']['l_weight']['rgb'] != 0.: - psr_grad = approx_psr_grad(psr_grid, self.cfg['model']['grid_res']) + + if self.cfg["train"]["l_weight"]["rgb"] != 0.0: + psr_grad = approx_psr_grad(psr_grid, self.cfg["model"]["grid_res"]) p_inters, visible_mask = self.psr2sur(psr_grid, pose, self.img_size, uv, psr_grad, None) n_inters = grid_interp(psr_grad[None], (p_inters.detach()[None] + 1) / 2) fea_interp = None - if 'rays' in data.keys(): - ray = data['rays'].squeeze()[idx][visible_mask] - pred_rgb = self.rendering_network(p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp) - + if "rays" in data.keys(): + ray = data["rays"].squeeze()[idx][visible_mask] + pred_rgb = self.rendering_network( + p_inters, normals=n_inters.squeeze(), view_dirs=ray, feature_vectors=fea_interp, + ) + # silhouette loss - if self.cfg['train']['l_weight']['mask'] != 0.: + if self.cfg["train"]["l_weight"]["mask"] != 0.0: # build mesh v, f, _ = self.psr2mesh(psr_grid) - v = v * 2. - 1 # within the range of [-1, 1] + v = v * 2.0 - 1 # within the range of [-1, 1] # ! 
Fast but more GPU usage mesh = Meshes(verts=[v.squeeze()], faces=[f.squeeze()]) if True: @@ -114,11 +117,7 @@ class SAP2Image(nn.Module): T = torch.cat([p.T for p in pose], dim=0) focal = torch.cat([p.focal_length for p in pose], dim=0) pp = torch.cat([p.principal_point for p in pose], dim=0) - pose_cur = PerspectiveCameras( - focal_length=focal, - principal_point=pp, - R=R, T=T, - device=R.device) + pose_cur = PerspectiveCameras(focal_length=focal, principal_point=pp, R=R, T=T, device=R.device) pred_mask = self.mesh_rasterizer(mesh.extend(n_views_per_iter), cameras=pose_cur)[..., 3] else: pred_mask = [] @@ -129,11 +128,11 @@ class SAP2Image(nn.Module): pred_mask = torch.cat(pred_mask, dim=0) output = { - 'rgb': pred_rgb, - 'rgb_gt': rgb, - 'mask': pred_mask, - 'mask_gt': mask_gt, - 'vis_mask': visible_mask, - } - - return output \ No newline at end of file + "rgb": pred_rgb, + "rgb_gt": rgb, + "mask": pred_mask, + "mask_gt": mask_gt, + "vis_mask": visible_mask, + } + + return output diff --git a/src/network/__init__.py b/src/network/__init__.py index f6a0b61..ecc0b2e 100644 --- a/src/network/__init__.py +++ b/src/network/__init__.py @@ -1,8 +1,8 @@ -from src.network import encoder, decoder +from src.network import decoder, encoder encoder_dict = { - 'local_pool_pointnet': encoder.LocalPoolPointnet, + "local_pool_pointnet": encoder.LocalPoolPointnet, } decoder_dict = { - 'simple_local': decoder.LocalDecoder, -} \ No newline at end of file + "simple_local": decoder.LocalDecoder, +} diff --git a/src/network/decoder.py b/src/network/decoder.py index 5fc7bd1..9c40b8f 100644 --- a/src/network/decoder.py +++ b/src/network/decoder.py @@ -1,15 +1,17 @@ -import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -from ipdb import set_trace as st -from src.network.utils import normalize_3d_coordinate, ResnetBlockFC, \ - normalize_coordinate + +from src.network.utils import ( + ResnetBlockFC, + normalize_3d_coordinate, + normalize_coordinate, +) class LocalDecoder(nn.Module): - ''' Decoder. + """Decoder. Instead of conditioning on global features, on plane/volume local features. + Args: dim (int): input dimension c_dim (int): dimension of latent conditioned code c @@ -17,26 +19,31 @@ class LocalDecoder(nn.Module): n_blocks (int): number of blocks ResNetBlockFC layers leaky (bool): whether to use leaky ReLUs sample_mode (str): sampling feature strategy, bilinear|nearest - padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] - ''' + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]. 
+ """ - def __init__(self, dim=3, c_dim=128, out_dim=3, - hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', padding=0.1, map2local=None): + def __init__( + self, + dim=3, + c_dim=128, + out_dim=3, + hidden_size=256, + n_blocks=5, + leaky=False, + sample_mode="bilinear", + padding=0.1, + map2local=None, + ): super().__init__() self.c_dim = c_dim self.n_blocks = n_blocks if c_dim != 0: - self.fc_c = nn.ModuleList([ - nn.Linear(c_dim, hidden_size) for i in range(n_blocks) - ]) - + self.fc_c = nn.ModuleList([nn.Linear(c_dim, hidden_size) for i in range(n_blocks)]) self.fc_p = nn.Linear(dim, hidden_size) - self.blocks = nn.ModuleList([ - ResnetBlockFC(hidden_size) for i in range(n_blocks) - ]) + self.blocks = nn.ModuleList([ResnetBlockFC(hidden_size) for i in range(n_blocks)]) self.fc_out = nn.Linear(hidden_size, out_dim) @@ -49,46 +56,45 @@ class LocalDecoder(nn.Module): self.padding = padding self.map2local = map2local self.out_dim = out_dim - - def sample_plane_feature(self, p, c, plane='xz'): - xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1) + def sample_plane_feature(self, p, c, plane="xz"): + xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1) xy = xy[:, :, None].float() - vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) - c = F.grid_sample(c, vgrid, padding_mode='border', - align_corners=True, - mode=self.sample_mode).squeeze(-1) + vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) + c = F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode).squeeze(-1) return c def sample_grid_feature(self, p, c): p_nor = normalize_3d_coordinate(p.clone()) p_nor = p_nor[:, :, None, None].float() - vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1) + vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1) # acutally trilinear interpolation if mode = 'bilinear' - c = F.grid_sample(c, vgrid, padding_mode='border', - align_corners=True, - mode=self.sample_mode).squeeze(-1).squeeze(-1) + c = ( + F.grid_sample(c, vgrid, padding_mode="border", align_corners=True, mode=self.sample_mode) + .squeeze(-1) + .squeeze(-1) + ) return c def forward(self, p, c_plane, **kwargs): batch_size = p.shape[0] plane_type = list(c_plane.keys()) c = 0 - if 'grid' in plane_type: - c += self.sample_grid_feature(p, c_plane['grid']) - if 'xz' in plane_type: - c += self.sample_plane_feature(p, c_plane['xz'], plane='xz') - if 'xy' in plane_type: - c += self.sample_plane_feature(p, c_plane['xy'], plane='xy') - if 'yz' in plane_type: - c += self.sample_plane_feature(p, c_plane['yz'], plane='yz') + if "grid" in plane_type: + c += self.sample_grid_feature(p, c_plane["grid"]) + if "xz" in plane_type: + c += self.sample_plane_feature(p, c_plane["xz"], plane="xz") + if "xy" in plane_type: + c += self.sample_plane_feature(p, c_plane["xy"], plane="xy") + if "yz" in plane_type: + c += self.sample_plane_feature(p, c_plane["yz"], plane="yz") c = c.transpose(1, 2) p = p.float() if self.map2local: p = self.map2local(p) - + net = self.fc_p(p) for i in range(self.n_blocks): @@ -98,9 +104,8 @@ class LocalDecoder(nn.Module): net = self.blocks[i](net) out = self.fc_out(self.actvn(net)) - if self.out_dim > 3: out = out.reshape(batch_size, -1, 3) - - return out \ No newline at end of file + + return out diff --git a/src/network/encoder.py b/src/network/encoder.py index 4385a9f..feac77f 100644 --- a/src/network/encoder.py +++ b/src/network/encoder.py @@ -1,17 +1,23 @@ import torch import torch.nn as nn -import numpy as np -from src.network.unet3d import 
UNet3D +from torch_scatter import scatter_max, scatter_mean + from src.network.unet import UNet -from ipdb import set_trace as st -from torch_scatter import scatter_mean, scatter_max -from src.network.utils import get_embedder, normalize_3d_coordinate,\ -coordinate2index, ResnetBlockFC, normalize_coordinate +from src.network.unet3d import UNet3D +from src.network.utils import ( + ResnetBlockFC, + coordinate2index, + get_embedder, + normalize_3d_coordinate, + normalize_coordinate, +) + class LocalPoolPointnet(nn.Module): - ''' PointNet-based encoder network with ResNet blocks for each point. + """PointNet-based encoder network with ResNet blocks for each point. Number of input points are fixed. - + + Args: c_dim (int): dimension of latent code c dim (int): input points dimension @@ -22,26 +28,38 @@ class LocalPoolPointnet(nn.Module): unet3d (bool): weather to use 3D U-Net unet3d_kwargs (str): 3D U-Net parameters plane_resolution (int): defined resolution for plane feature - grid_resolution (int): defined resolution for grid feature + grid_resolution (int): defined resolution for grid feature plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] n_blocks (int): number of blocks ResNetBlockFC layers map2local (function): map global coordintes to local ones pos_encoding (int): frequency for the positional encoding - ''' + """ - def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', - unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, - plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5, - map2local=None, pos_encoding=0): + def __init__( + self, + c_dim=128, + dim=3, + hidden_dim=128, + scatter_type="max", + unet=False, + unet_kwargs=None, + unet3d=False, + unet3d_kwargs=None, + plane_resolution=None, + grid_resolution=None, + plane_type="xz", + padding=0.1, + n_blocks=5, + map2local=None, + pos_encoding=0, + ): super().__init__() self.c_dim = c_dim - self.fc_pos = nn.Linear(dim, 2*hidden_dim) - self.blocks = nn.ModuleList([ - ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) - ]) + self.fc_pos = nn.Linear(dim, 2 * hidden_dim) + self.blocks = nn.ModuleList([ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)]) self.fc_c = nn.Linear(hidden_dim, c_dim) self.actvn = nn.ReLU() @@ -59,34 +77,36 @@ class LocalPoolPointnet(nn.Module): self.reso_grid = grid_resolution self.plane_type = plane_type self.padding = padding - - self.fc_pos = nn.Linear(dim, 2*hidden_dim) + + self.fc_pos = nn.Linear(dim, 2 * hidden_dim) self.pe = None if pos_encoding > 0: embed_fn, input_ch = get_embedder(pos_encoding, d_in=dim) self.pe = embed_fn - self.fc_pos = nn.Linear(input_ch, 2*hidden_dim) + self.fc_pos = nn.Linear(input_ch, 2 * hidden_dim) self.map2local = map2local - - if scatter_type == 'max': + if scatter_type == "max": self.scatter = scatter_max - elif scatter_type == 'mean': + elif scatter_type == "mean": self.scatter = scatter_mean else: - raise ValueError('incorrect scatter type') + msg = "incorrect scatter type" + raise ValueError(msg) - def generate_plane_features(self, p, c, plane='xz'): + def generate_plane_features(self, p, c, plane="xz"): # acquire indices of features in plane - xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1) + xy = normalize_coordinate(p.clone(), plane=plane) # normalize to the range of (0, 1) index = 
coordinate2index(xy, self.reso_plane) # scatter plane features from points fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2) - c = c.permute(0, 2, 1) # B x 512 x T - fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 - fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso) + c = c.permute(0, 2, 1) # B x 512 x T + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape( + p.size(0), self.c_dim, self.reso_plane, self.reso_plane, + ) # sparce matrix (B x 512 x reso x reso) # process the plane features with UNet if self.unet is not None: @@ -96,12 +116,14 @@ class LocalPoolPointnet(nn.Module): def generate_grid_features(self, p, c): p_nor = normalize_3d_coordinate(p.clone()) - index = coordinate2index(p_nor, self.reso_grid, coord_type='3d') + index = coordinate2index(p_nor, self.reso_grid, coord_type="3d") # scatter grid features from points fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3) c = c.permute(0, 2, 1) - fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3 - fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparce matrix (B x 512 x reso x reso) + fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3 + fea_grid = fea_grid.reshape( + p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid, + ) # sparce matrix (B x 512 x reso x reso) if self.unet3d is not None: fea_grid = self.unet3d(fea_grid) @@ -115,7 +137,7 @@ class LocalPoolPointnet(nn.Module): c_out = 0 for key in keys: # scatter plane features from points - if key == 'grid': + if key == "grid": fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3) else: fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2) @@ -126,33 +148,31 @@ class LocalPoolPointnet(nn.Module): c_out += fea return c_out.permute(0, 2, 1) - def forward(self, p, normalize=True): batch_size, T, D = p.size() # acquire the index for each point coord = {} index = {} - if 'xz' in self.plane_type: - coord['xz'] = normalize_coordinate(p.clone(), plane='xz') - index['xz'] = coordinate2index(coord['xz'], self.reso_plane) - if 'xy' in self.plane_type: - coord['xy'] = normalize_coordinate(p.clone(), plane='xy') - index['xy'] = coordinate2index(coord['xy'], self.reso_plane) - if 'yz' in self.plane_type: - coord['yz'] = normalize_coordinate(p.clone(), plane='yz') - index['yz'] = coordinate2index(coord['yz'], self.reso_plane) - if 'grid' in self.plane_type: + if "xz" in self.plane_type: + coord["xz"] = normalize_coordinate(p.clone(), plane="xz") + index["xz"] = coordinate2index(coord["xz"], self.reso_plane) + if "xy" in self.plane_type: + coord["xy"] = normalize_coordinate(p.clone(), plane="xy") + index["xy"] = coordinate2index(coord["xy"], self.reso_plane) + if "yz" in self.plane_type: + coord["yz"] = normalize_coordinate(p.clone(), plane="yz") + index["yz"] = coordinate2index(coord["yz"], self.reso_plane) + if "grid" in self.plane_type: if normalize: - coord['grid'] = normalize_3d_coordinate(p.clone()) + coord["grid"] = normalize_3d_coordinate(p.clone()) else: - coord['grid'] = p.clone()[...,:3] - index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d') + coord["grid"] = p.clone()[..., :3] + index["grid"] = coordinate2index(coord["grid"], self.reso_grid, coord_type="3d") - if self.pe: p = self.pe(p) - + if self.map2local: pp = self.map2local(p) net = 
self.fc_pos(pp) @@ -169,13 +189,13 @@ class LocalPoolPointnet(nn.Module): c = self.fc_c(net) fea = {} - if 'grid' in self.plane_type: - fea['grid'] = self.generate_grid_features(p, c) - if 'xz' in self.plane_type: - fea['xz'] = self.generate_plane_features(p, c, plane='xz') - if 'xy' in self.plane_type: - fea['xy'] = self.generate_plane_features(p, c, plane='xy') - if 'yz' in self.plane_type: - fea['yz'] = self.generate_plane_features(p, c, plane='yz') + if "grid" in self.plane_type: + fea["grid"] = self.generate_grid_features(p, c) + if "xz" in self.plane_type: + fea["xz"] = self.generate_plane_features(p, c, plane="xz") + if "xy" in self.plane_type: + fea["xy"] = self.generate_plane_features(p, c, plane="xy") + if "yz" in self.plane_type: + fea["yz"] = self.generate_plane_features(p, c, plane="yz") - return fea \ No newline at end of file + return fea diff --git a/src/network/net_rgb.py b/src/network/net_rgb.py index acdde01..3f5a444 100644 --- a/src/network/net_rgb.py +++ b/src/network/net_rgb.py @@ -1,39 +1,41 @@ # code from IDR (https://github.com/lioryariv/idr/blob/main/code/model/implicit_differentiable_renderer.py) + +import numpy as np import torch import torch.nn as nn -import numpy as np + from src.network.utils import get_embedder -from pdb import set_trace as st + class RenderingNetwork(nn.Module): def __init__( - self, - fea_size=0, - mode='naive', - d_out=3, - dims=[512, 512, 512, 512], - weight_norm=True, - pe_freq_view=0 # for positional encoding + self, + fea_size=0, + mode="naive", + d_out=3, + dims=[512, 512, 512, 512], + weight_norm=True, + pe_freq_view=0, # for positional encoding ): super().__init__() - + self.mode = mode - if mode == 'naive': + if mode == "naive": d_in = 3 - elif mode == 'no_feature': + elif mode == "no_feature": d_in = 3 + 3 + 3 fea_size = 0 - elif mode == 'full': + elif mode == "full": d_in = 3 + 3 + 3 else: d_in = 3 + 3 - dims = [d_in + fea_size] + dims + [d_out] + dims = [d_in + fea_size, *dims, d_out] self.embedview_fn = None if pe_freq_view > 0: embedview_fn, input_ch = get_embedder(pe_freq_view, d_in=3) self.embedview_fn = embedview_fn - dims[0] += (input_ch - 3) + dims[0] += input_ch - 3 self.num_layers = len(dims) @@ -54,13 +56,13 @@ class RenderingNetwork(nn.Module): view_dirs = self.embedview_fn(view_dirs) # points = self.embedview_fn(points) - if (self.mode == 'full') & (feature_vectors is not None): + if (self.mode == "full") & (feature_vectors is not None): rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) - elif (self.mode == 'no_feature') | ((self.mode == 'full') & (feature_vectors is None)): + elif (self.mode == "no_feature") | ((self.mode == "full") & (feature_vectors is None)): rendering_input = torch.cat([points, view_dirs, normals], dim=-1) - elif self.mode == 'no_view_dir': + elif self.mode == "no_view_dir": rendering_input = torch.cat([points, normals], dim=-1) - elif self.mode == 'no_normal': + elif self.mode == "no_normal": rendering_input = torch.cat([points, view_dirs], dim=-1) else: rendering_input = points @@ -81,28 +83,27 @@ class RenderingNetwork(nn.Module): class NeRFRenderingNetwork(nn.Module): def __init__( - self, - feature_vector_size=0, - mode='naive', - d_in=3, - d_out=3, - dims=[512, 512, 512, 256], - weight_norm=True, - multires=0, # positional encoding of points - multires_view=0 # positional encoding of view + self, + feature_vector_size=0, + mode="naive", + d_in=3, + d_out=3, + dims=[512, 512, 512, 256], + weight_norm=True, + multires=0, # positional encoding of points + 
multires_view=0, # positional encoding of view ): super().__init__() self.mode = mode - dims = [d_in + feature_vector_size] + dims - + dims = [d_in + feature_vector_size, *dims] self.embed_fn = None if multires > 0: embed_fn, input_ch = get_embedder(multires, d_in=d_in) self.embed_fn = embed_fn - dims[0] += (input_ch - 3) - + dims[0] += input_ch - 3 + self.num_layers = len(dims) self.pts_net = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.num_layers - 1)]) @@ -113,13 +114,12 @@ class NeRFRenderingNetwork(nn.Module): self.embedview_fn = embedview_fn # dims[0] += (input_ch - 3) - if mode == 'full': - self.view_net = nn.ModuleList([nn.Linear(dims[-1]+view_ch, 128)]) + if mode == "full": + self.view_net = nn.ModuleList([nn.Linear(dims[-1] + view_ch, 128)]) self.rgb_net = nn.Linear(128, 3) - else: + else: self.rgb_net = nn.Linear(dims[-1], 3) - self.relu = nn.ReLU() self.tanh = nn.Tanh() @@ -134,7 +134,7 @@ class NeRFRenderingNetwork(nn.Module): x = net(x) x = self.relu(x) - if self.mode=='full': + if self.mode == "full": x = torch.cat([x, view_dirs], -1) for net in self.view_net: x = net(x) @@ -144,22 +144,23 @@ class NeRFRenderingNetwork(nn.Module): x = self.tanh(x) return x + class ImplicitNetwork(nn.Module): def __init__( - self, - d_in, - d_out, - dims, - geometric_init=True, - feature_vector_size=0, - bias=1.0, - skip_in=(), - weight_norm=True, - multires=0 + self, + d_in, + d_out, + dims, + geometric_init=True, + feature_vector_size=0, + bias=1.0, + skip_in=(), + weight_norm=True, + multires=0, ): super().__init__() - dims = [d_in] + dims + [d_out + feature_vector_size] + dims = [d_in, *dims, d_out + feature_vector_size] self.embed_fn = None if multires > 0: @@ -189,7 +190,7 @@ class ImplicitNetwork(nn.Module): elif multires > 0 and l in self.skip_in: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) - torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) + torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0) else: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) @@ -222,13 +223,9 @@ class ImplicitNetwork(nn.Module): def gradient(self, x): x.requires_grad_(True) - y = self.forward(x)[:,:1] + y = self.forward(x)[:, :1] d_output = torch.ones_like(y, requires_grad=False, device=y.device) gradients = torch.autograd.grad( - outputs=y, - inputs=x, - grad_outputs=d_output, - create_graph=True, - retain_graph=True, - only_inputs=True)[0] - return gradients.unsqueeze(1) \ No newline at end of file + outputs=y, inputs=x, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True, + )[0] + return gradients.unsqueeze(1) diff --git a/src/network/unet.py b/src/network/unet.py index a58cc31..57a1cc7 100644 --- a/src/network/unet.py +++ b/src/network/unet.py @@ -1,55 +1,38 @@ -''' -Codes are from: -https://github.com/jaxony/unet-pytorch/blob/master/model.py -''' +"""Codes are from: +https://github.com/jaxony/unet-pytorch/blob/master/model.py. 
+""" +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import init -import numpy as np -def conv3x3(in_channels, out_channels, stride=1, - padding=1, bias=True, groups=1): - return nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=padding, - bias=bias, - groups=groups) -def upconv2x2(in_channels, out_channels, mode='transpose'): - if mode == 'transpose': - return nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size=2, - stride=2) +def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups) + + +def upconv2x2(in_channels, out_channels, mode="transpose"): + if mode == "transpose": + return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) else: # out_channels is always going to be the same # as in_channels - return nn.Sequential( - nn.Upsample(mode='bilinear', scale_factor=2), - conv1x1(in_channels, out_channels)) + return nn.Sequential(nn.Upsample(mode="bilinear", scale_factor=2), conv1x1(in_channels, out_channels)) + def conv1x1(in_channels, out_channels, groups=1): - return nn.Conv2d( - in_channels, - out_channels, - kernel_size=1, - groups=groups, - stride=1) + return nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, stride=1) class DownConv(nn.Module): - """ - A helper Module that performs 2 convolutions and 1 MaxPool. + """A helper Module that performs 2 convolutions and 1 MaxPool. A ReLU activation follows each convolution. """ + def __init__(self, in_channels, out_channels, pooling=True): - super(DownConv, self).__init__() + super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -71,39 +54,35 @@ class DownConv(nn.Module): class UpConv(nn.Module): - """ - A helper Module that performs 2 convolutions and 1 UpConvolution. + """A helper Module that performs 2 convolutions and 1 UpConvolution. A ReLU activation follows each convolution. """ - def __init__(self, in_channels, out_channels, - merge_mode='concat', up_mode='transpose'): - super(UpConv, self).__init__() + + def __init__(self, in_channels, out_channels, merge_mode="concat", up_mode="transpose"): + super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.merge_mode = merge_mode self.up_mode = up_mode - self.upconv = upconv2x2(self.in_channels, self.out_channels, - mode=self.up_mode) + self.upconv = upconv2x2(self.in_channels, self.out_channels, mode=self.up_mode) - if self.merge_mode == 'concat': - self.conv1 = conv3x3( - 2*self.out_channels, self.out_channels) + if self.merge_mode == "concat": + self.conv1 = conv3x3(2 * self.out_channels, self.out_channels) else: # num of input channels to conv2 is same self.conv1 = conv3x3(self.out_channels, self.out_channels) self.conv2 = conv3x3(self.out_channels, self.out_channels) - def forward(self, from_down, from_up): - """ Forward pass + """Forward pass Arguments: from_down: tensor from the encoder pathway - from_up: upconv'd tensor from the decoder pathway + from_up: upconv'd tensor from the decoder pathway. 
""" from_up = self.upconv(from_up) - if self.merge_mode == 'concat': + if self.merge_mode == "concat": x = torch.cat((from_up, from_down), 1) else: x = from_up + from_down @@ -113,7 +92,7 @@ class UpConv(nn.Module): class UNet(nn.Module): - """ `UNet` class is based on https://arxiv.org/abs/1505.04597 + """`UNet` class is based on https://arxiv.org/abs/1505.04597. The U-Net is a convolutional encoder-decoder neural network. Contextual spatial information (from the decoding, @@ -135,44 +114,37 @@ class UNet(nn.Module): the tranpose convolution (specified by upmode='transpose') """ - def __init__(self, num_classes, in_channels=3, depth=5, - start_filts=64, up_mode='transpose', - merge_mode='concat', **kwargs): + def __init__( + self, num_classes, in_channels=3, depth=5, start_filts=64, up_mode="transpose", merge_mode="concat", **kwargs, + ): + """Arguments: + in_channels: int, number of channels in the input tensor. + Default is 3 for RGB images. + depth: int, number of MaxPools in the U-Net. + start_filts: int, number of convolutional filters for the + first conv. + up_mode: string, type of upconvolution. Choices: 'transpose' + for transpose convolution or 'upsample' for nearest neighbour + upsampling. """ - Arguments: - in_channels: int, number of channels in the input tensor. - Default is 3 for RGB images. - depth: int, number of MaxPools in the U-Net. - start_filts: int, number of convolutional filters for the - first conv. - up_mode: string, type of upconvolution. Choices: 'transpose' - for transpose convolution or 'upsample' for nearest neighbour - upsampling. - """ - super(UNet, self).__init__() + super().__init__() - if up_mode in ('transpose', 'upsample'): + if up_mode in ("transpose", "upsample"): self.up_mode = up_mode else: - raise ValueError("\"{}\" is not a valid mode for " - "upsampling. Only \"transpose\" and " - "\"upsample\" are allowed.".format(up_mode)) - - if merge_mode in ('concat', 'add'): + msg = f'"{up_mode}" is not a valid mode for upsampling. Only "transpose" and "upsample" are allowed.' + raise ValueError(msg) + + if merge_mode in ("concat", "add"): self.merge_mode = merge_mode else: - raise ValueError("\"{}\" is not a valid mode for" - "merging up and down paths. " - "Only \"concat\" and " - "\"add\" are allowed.".format(up_mode)) + msg = f'"{up_mode}" is not a valid mode formerging up and down paths. Only "concat" and "add" are allowed.' + raise ValueError(msg) # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' - if self.up_mode == 'upsample' and self.merge_mode == 'add': - raise ValueError("up_mode \"upsample\" is incompatible " - "with merge_mode \"add\" at the moment " - "because it doesn't make sense to use " - "nearest neighbour to reduce " - "depth channels (by half).") + if self.up_mode == "upsample" and self.merge_mode == "add": + msg = 'up_mode "upsample" is incompatible with merge_mode "add" at the moment because it doesn\'t make sense to use nearest neighbour to reduce depth channels (by half).' + raise ValueError(msg) self.num_classes = num_classes self.in_channels = in_channels @@ -185,19 +157,18 @@ class UNet(nn.Module): # create the encoder pathway and add to a list for i in range(depth): ins = self.in_channels if i == 0 else outs - outs = self.start_filts*(2**i) - pooling = True if i < depth-1 else False + outs = self.start_filts * (2**i) + pooling = True if i < depth - 1 else False down_conv = DownConv(ins, outs, pooling=pooling) self.down_convs.append(down_conv) # create the decoder pathway and add to a list # - careful! 
decoding only requires depth-1 blocks - for i in range(depth-1): + for i in range(depth - 1): ins = outs outs = ins // 2 - up_conv = UpConv(ins, outs, up_mode=up_mode, - merge_mode=merge_mode) + up_conv = UpConv(ins, outs, up_mode=up_mode, merge_mode=merge_mode) self.up_convs.append(up_conv) # add the list of modules to current module @@ -214,12 +185,10 @@ class UNet(nn.Module): init.xavier_normal_(m.weight) init.constant_(m.bias, 0) - def reset_params(self): - for i, m in enumerate(self.modules()): + for _i, m in enumerate(self.modules()): self.weight_init(m) - def forward(self, x): encoder_outs = [] # encoder pathway, save outputs for merging @@ -227,30 +196,31 @@ class UNet(nn.Module): x, before_pool = module(x) encoder_outs.append(before_pool) for i, module in enumerate(self.up_convs): - before_pool = encoder_outs[-(i+2)] + before_pool = encoder_outs[-(i + 2)] x = module(before_pool, x) - + # No softmax is used. This means you need to use # nn.CrossEntropyLoss is your training script, # as this module includes a softmax already. x = self.conv_final(x) return x + if __name__ == "__main__": """ testing """ - model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32) + model = UNet(1, depth=5, merge_mode="concat", in_channels=1, start_filts=32) print(model) print(sum(p.numel() for p in model.parameters())) reso = 176 x = np.zeros((1, 1, reso, reso)) - x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan + x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan x = torch.FloatTensor(x) out = model(x) - print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso))) - + print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso))) + # loss = torch.sum(out) # loss.backward() diff --git a/src/network/unet3d.py b/src/network/unet3d.py index 3c2bae2..324fa31 100644 --- a/src/network/unet3d.py +++ b/src/network/unet3d.py @@ -1,16 +1,18 @@ -''' -Code from the 3D UNet implementation: -https://github.com/wolny/pytorch-3dunet/ -''' +"""Code from the 3D UNet implementation: +https://github.com/wolny/pytorch-3dunet/. +""" import importlib +from functools import partial + import torch import torch.nn as nn from torch.nn import functional as F -from functools import partial + from src.network.utils import get_embedder + def number_of_features_per_level(init_channel_number, num_levels): - return [init_channel_number * 2 ** k for k in range(num_levels)] + return [init_channel_number * 2**k for k in range(num_levels)] def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): @@ -18,8 +20,7 @@ def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): - """ - Create a list of modules with together constitute a single conv layer with non-linearity + """Create a list of modules with together constitute a single conv layer with non-linearity and optional batchnorm/groupnorm. 
Args: @@ -37,23 +38,23 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi Return: list of tuple (name, module) """ - assert 'c' in order, "Conv layer MUST be present" - assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' + assert "c" in order, "Conv layer MUST be present" + assert order[0] not in "rle", "Non-linearity cannot be the first operation in the layer" modules = [] for i, char in enumerate(order): - if char == 'r': - modules.append(('ReLU', nn.ReLU(inplace=True))) - elif char == 'l': - modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) - elif char == 'e': - modules.append(('ELU', nn.ELU(inplace=True))) - elif char == 'c': + if char == "r": + modules.append(("ReLU", nn.ReLU(inplace=True))) + elif char == "l": + modules.append(("LeakyReLU", nn.LeakyReLU(negative_slope=0.1, inplace=True))) + elif char == "e": + modules.append(("ELU", nn.ELU(inplace=True))) + elif char == "c": # add learnable bias only in the absence of batchnorm/groupnorm - bias = not ('g' in order or 'b' in order) - modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) - elif char == 'g': - is_before_conv = i < order.index('c') + bias = not ("g" in order or "b" in order) + modules.append(("conv", conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) + elif char == "g": + is_before_conv = i < order.index("c") if is_before_conv: num_channels = in_channels else: @@ -63,14 +64,16 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi if num_channels < num_groups: num_groups = 1 - assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' - modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) - elif char == 'b': - is_before_conv = i < order.index('c') + assert ( + num_channels % num_groups == 0 + ), f"Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}" + modules.append(("groupnorm", nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) + elif char == "b": + is_before_conv = i < order.index("c") if is_before_conv: - modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) + modules.append(("batchnorm", nn.BatchNorm3d(in_channels))) else: - modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) + modules.append(("batchnorm", nn.BatchNorm3d(out_channels))) else: raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") @@ -78,9 +81,8 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi class SingleConv(nn.Sequential): - """ - Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order - of operations can be specified via the `order` parameter + """Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order + of operations can be specified via the `order` parameter. 
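# A small sketch of how the `order` string is interpreted, assuming the create_conv
# above: "crg" yields a Conv3d (bias disabled because a norm follows), then ReLU,
# then GroupNorm over the conv's output channels.
layers = create_conv(in_channels=16, out_channels=32, kernel_size=3, order="crg", num_groups=8)
assert [name for name, _ in layers] == ["conv", "ReLU", "groupnorm"]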
Args: in_channels (int): number of input channels @@ -94,16 +96,15 @@ class SingleConv(nn.Sequential): num_groups (int): number of groups for the GroupNorm """ - def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1): - super(SingleConv, self).__init__() + def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8, padding=1): + super().__init__() for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): self.add_module(name, module) class DoubleConv(nn.Sequential): - """ - A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). + """A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). We use (Conv3d+ReLU+GroupNorm3d) by default. This can be changed however by providing the 'order' argument, e.g. in order to change to Conv3d+BatchNorm3d+ELU use order='cbe'. @@ -123,8 +124,8 @@ class DoubleConv(nn.Sequential): num_groups (int): number of groups for the GroupNorm """ - def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8): - super(DoubleConv, self).__init__() + def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order="crg", num_groups=8): + super().__init__() if encoder: # we're in the encoder path conv1_in_channels = in_channels @@ -138,26 +139,27 @@ class DoubleConv(nn.Sequential): conv2_in_channels, conv2_out_channels = out_channels, out_channels # conv1 - self.add_module('SingleConv1', - SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) + self.add_module( + "SingleConv1", SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups), + ) # conv2 - self.add_module('SingleConv2', - SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) + self.add_module( + "SingleConv2", SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups), + ) class ExtResNetBlock(nn.Module): - """ - Basic UNet block consisting of a SingleConv followed by the residual block. + """Basic UNet block consisting of a SingleConv followed by the residual block. The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number of output channels is compatible with the residual block that follows. This block can be used instead of standard DoubleConv in the Encoder module. - Motivated by: https://arxiv.org/pdf/1706.00120.pdf + Motivated by: https://arxiv.org/pdf/1706.00120.pdf. Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 
""" - def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): - super(ExtResNetBlock, self).__init__() + def __init__(self, in_channels, out_channels, kernel_size=3, order="cge", num_groups=8, **kwargs): + super().__init__() # first convolution self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) @@ -165,15 +167,16 @@ class ExtResNetBlock(nn.Module): self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual n_order = order - for c in 'rel': - n_order = n_order.replace(c, '') - self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, - num_groups=num_groups) + for c in "rel": + n_order = n_order.replace(c, "") + self.conv3 = SingleConv( + out_channels, out_channels, kernel_size=kernel_size, order=n_order, num_groups=num_groups, + ) # create non-linearity separately - if 'l' in order: + if "l" in order: self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) - elif 'e' in order: + elif "e" in order: self.non_linearity = nn.ELU(inplace=True) else: self.non_linearity = nn.ReLU(inplace=True) @@ -194,12 +197,12 @@ class ExtResNetBlock(nn.Module): class Encoder(nn.Module): - """ - A single module from the encoder path consisting of the optional max + """A single module from the encoder path consisting of the optional max pooling layer (one may specify the MaxPool kernel_size to be different than the standard (2,2,2), e.g. if the volumetric data is anisotropic (make sure to use complementary scale_factor in the decoder path) followed by a DoubleConv module. + Args: in_channels (int): number of input channels out_channels (int): number of output channels @@ -210,27 +213,39 @@ class Encoder(nn.Module): basic_module(nn.Module): either ResNetBlock or DoubleConv conv_layer_order (string): determines the order of layers in `DoubleConv` module. See `DoubleConv` for more info. - num_groups (int): number of groups for the GroupNorm + num_groups (int): number of groups for the GroupNorm. 
""" - def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, - pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg', - num_groups=8): - super(Encoder, self).__init__() - assert pool_type in ['max', 'avg'] + def __init__( + self, + in_channels, + out_channels, + conv_kernel_size=3, + apply_pooling=True, + pool_kernel_size=(2, 2, 2), + pool_type="max", + basic_module=DoubleConv, + conv_layer_order="crg", + num_groups=8, + ): + super().__init__() + assert pool_type in ["max", "avg"] if apply_pooling: - if pool_type == 'max': + if pool_type == "max": self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) else: self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) else: self.pooling = None - self.basic_module = basic_module(in_channels, out_channels, - encoder=True, - kernel_size=conv_kernel_size, - order=conv_layer_order, - num_groups=num_groups) + self.basic_module = basic_module( + in_channels, + out_channels, + encoder=True, + kernel_size=conv_kernel_size, + order=conv_layer_order, + num_groups=num_groups, + ) def forward(self, x): if self.pooling is not None: @@ -240,9 +255,9 @@ class Encoder(nn.Module): class Decoder(nn.Module): - """ - A single module for decoder path consisting of the upsampling layer + """A single module for decoder path consisting of the upsampling layer (either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock). + Args: in_channels (int): number of input channels out_channels (int): number of output channels @@ -253,32 +268,56 @@ class Decoder(nn.Module): basic_module(nn.Module): either ResNetBlock or DoubleConv conv_layer_order (string): determines the order of layers in `DoubleConv` module. See `DoubleConv` for more info. - num_groups (int): number of groups for the GroupNorm + num_groups (int): number of groups for the GroupNorm. 
""" - def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv, - conv_layer_order='crg', num_groups=8, mode='nearest'): - super(Decoder, self).__init__() + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + scale_factor=(2, 2, 2), + basic_module=DoubleConv, + conv_layer_order="crg", + num_groups=8, + mode="nearest", + ): + super().__init__() if basic_module == DoubleConv: # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining - self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) + self.upsampling = Upsampling( + transposed_conv=False, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + scale_factor=scale_factor, + mode=mode, + ) # concat joining self.joining = partial(self._joining, concat=True) else: # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining - self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) + self.upsampling = Upsampling( + transposed_conv=True, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + scale_factor=scale_factor, + mode=mode, + ) # sum joining self.joining = partial(self._joining, concat=False) # adapt the number of in_channels for the ExtResNetBlock in_channels = out_channels - self.basic_module = basic_module(in_channels, out_channels, - encoder=False, - kernel_size=kernel_size, - order=conv_layer_order, - num_groups=num_groups) + self.basic_module = basic_module( + in_channels, + out_channels, + encoder=False, + kernel_size=kernel_size, + order=conv_layer_order, + num_groups=num_groups, + ) def forward(self, encoder_features, x): x = self.upsampling(encoder_features=encoder_features, x=x) @@ -295,8 +334,7 @@ class Decoder(nn.Module): class Upsampling(nn.Module): - """ - Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution. + """Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution. Args: transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation @@ -310,15 +348,23 @@ class Upsampling(nn.Module): 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. 
Default: 'nearest' """ - def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3, - scale_factor=(2, 2, 2), mode='nearest'): - super(Upsampling, self).__init__() + def __init__( + self, + transposed_conv, + in_channels=None, + out_channels=None, + kernel_size=3, + scale_factor=(2, 2, 2), + mode="nearest", + ): + super().__init__() if transposed_conv: # make sure that the output size reverses the MaxPool3d from the corresponding encoder - # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) - self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, - padding=1) + # (D_out = (D_in - 1) x stride[0] - 2 x padding[0] + kernel_size[0] + output_padding[0]) + self.upsample = nn.ConvTranspose3d( + in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, padding=1, + ) else: self.upsample = partial(self._interpolate, mode=mode) @@ -332,13 +378,13 @@ class Upsampling(nn.Module): class FinalConv(nn.Sequential): - """ - A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution + """A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution which reduces the number of channels to 'out_channels'. with the number of output channels 'out_channels // 2' and 'out_channels' respectively. We use (Conv3d+ReLU+GroupNorm3d) by default. This can be change however by providing the 'order' argument, e.g. in order to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. + Args: in_channels (int): number of input channels out_channels (int): number of output channels @@ -346,22 +392,22 @@ class FinalConv(nn.Sequential): order (string): determines the order of layers, e.g. 'cr' -> conv + ReLU 'crg' -> conv + ReLU + groupnorm - num_groups (int): number of groups for the GroupNorm + num_groups (int): number of groups for the GroupNorm. """ - def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8): - super(FinalConv, self).__init__() + def __init__(self, in_channels, out_channels, kernel_size=3, order="crg", num_groups=8): + super().__init__() # conv1 - self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) + self.add_module("SingleConv", SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) - # in the last layer a 1×1 convolution reduces the number of output channels to out_channels + # in the last layer a 1x1 convolution reduces the number of output channels to out_channels final_conv = nn.Conv3d(in_channels, out_channels, 1) - self.add_module('final_conv', final_conv) + self.add_module("final_conv", final_conv) + class Abstract3DUNet(nn.Module): - """ - Base class for standard and residual UNet. + """Base class for standard and residual UNet. 
Args: in_channels (int): number of input channels @@ -391,9 +437,22 @@ class Abstract3DUNet(nn.Module): and the `final_activation` (even if present) won't be applied; default: False """ - def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', - num_groups=8, num_levels=4, is_segmentation=False, testing=False, pe_freq=0, **kwargs): - super(Abstract3DUNet, self).__init__() + def __init__( + self, + in_channels, + out_channels, + final_sigmoid, + basic_module, + f_maps=64, + layer_order="gcr", + num_groups=8, + num_levels=4, + is_segmentation=False, + testing=False, + pe_freq=0, + **kwargs, + ): + super().__init__() self.testing = testing @@ -411,13 +470,24 @@ class Abstract3DUNet(nn.Module): encoders = [] for i, out_feature_num in enumerate(f_maps): if i == 0: - encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module, - conv_layer_order=layer_order, num_groups=num_groups) + encoder = Encoder( + in_channels, + out_feature_num, + apply_pooling=False, + basic_module=basic_module, + conv_layer_order=layer_order, + num_groups=num_groups, + ) else: # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations # currently pools with a constant kernel: (2, 2, 2) - encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module, - conv_layer_order=layer_order, num_groups=num_groups) + encoder = Encoder( + f_maps[i - 1], + out_feature_num, + basic_module=basic_module, + conv_layer_order=layer_order, + num_groups=num_groups, + ) encoders.append(encoder) self.encoders = nn.ModuleList(encoders) @@ -434,13 +504,18 @@ class Abstract3DUNet(nn.Module): out_feature_num = reversed_f_maps[i + 1] # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv # currently strides with a constant stride: (2, 2, 2) - decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module, - conv_layer_order=layer_order, num_groups=num_groups) + decoder = Decoder( + in_feature_num, + out_feature_num, + basic_module=basic_module, + conv_layer_order=layer_order, + num_groups=num_groups, + ) decoders.append(decoder) self.decoders = nn.ModuleList(decoders) - # in the last layer a 1×1 convolution reduces the number of output + # in the last layer a 1x1 convolution reduces the number of output # channels to the number of labels self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) @@ -455,7 +530,6 @@ class Abstract3DUNet(nn.Module): self.final_activation = None def forward(self, x): - if self.embed_fn is not None: x = self.embed_fn(x.permute(0, 2, 3, 4, 1)) x = x.permute(0, 4, 1, 2, 3) @@ -488,49 +562,81 @@ class Abstract3DUNet(nn.Module): class UNet3D(Abstract3DUNet): - """ - 3DUnet model from + """3DUnet model from `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" `. 
Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder """ - def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', - num_groups=8, num_levels=4, is_segmentation=True, **kwargs): - super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid, - basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order, - num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation, - **kwargs) + def __init__( + self, + in_channels, + out_channels, + final_sigmoid=True, + f_maps=64, + layer_order="gcr", + num_groups=8, + num_levels=4, + is_segmentation=True, + **kwargs, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=DoubleConv, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_segmentation=is_segmentation, + **kwargs, + ) class ResidualUNet3D(Abstract3DUNet): - """ - Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. + """Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. Uses ExtResNetBlock as a basic building block, summation joining instead of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts). Since the model effectively becomes a residual net, in theory it allows for deeper UNet. """ - def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', - num_groups=8, num_levels=5, is_segmentation=True, **kwargs): - super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, - final_sigmoid=final_sigmoid, - basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order, - num_groups=num_groups, num_levels=num_levels, - is_segmentation=is_segmentation, - **kwargs) + def __init__( + self, + in_channels, + out_channels, + final_sigmoid=True, + f_maps=64, + layer_order="gcr", + num_groups=8, + num_levels=5, + is_segmentation=True, + **kwargs, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=ExtResNetBlock, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_segmentation=is_segmentation, + **kwargs, + ) def get_model(config): def _model_class(class_name): - m = importlib.import_module('pytorch3dunet.unet3d.model') + m = importlib.import_module("pytorch3dunet.unet3d.model") clazz = getattr(m, class_name) return clazz - assert 'model' in config, 'Could not find model configuration' - model_config = config['model'] - model_class = _model_class(model_config['name']) + assert "model" in config, "Could not find model configuration" + model_config = config["model"] + model_class = _model_class(model_config["name"]) return model_class(**model_config) @@ -542,18 +648,18 @@ if __name__ == "__main__": out_channels = 1 f_maps = 32 num_levels = 2 - model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order='cr') + model = UNet3D(in_channels, out_channels, f_maps=f_maps, num_levels=num_levels, layer_order="cr") print(model) - print('number of parameters: ', sum(p.numel() for p in model.parameters())) + print("number of parameters: ", sum(p.numel() for p in model.parameters())) reso = 18 - + import numpy as np import torch + x = np.zeros((1, 1, reso, reso, reso)) - x[:,:, int(reso/2-1), int(reso/2-1), int(reso/2-1)] = np.nan + x[:, 
:, int(reso / 2 - 1), int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan x = torch.FloatTensor(x) out = model(x) - print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso*reso))) - \ No newline at end of file + print("%f" % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso * reso))) diff --git a/src/network/utils.py b/src/network/utils.py index 68ea744..ccc7a18 100644 --- a/src/network/utils.py +++ b/src/network/utils.py @@ -1,8 +1,9 @@ -""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ +"""Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.""" import torch import torch.nn as nn + class Embedder: def __init__(self, **kwargs): self.kwargs = kwargs @@ -10,24 +11,23 @@ class Embedder: def create_embedding_fn(self): embed_fns = [] - d = self.kwargs['input_dims'] + d = self.kwargs["input_dims"] out_dim = 0 - if self.kwargs['include_input']: + if self.kwargs["include_input"]: embed_fns.append(lambda x: x) out_dim += d - max_freq = self.kwargs['max_freq_log2'] - N_freqs = self.kwargs['num_freqs'] + max_freq = self.kwargs["max_freq_log2"] + N_freqs = self.kwargs["num_freqs"] - if self.kwargs['log_sampling']: - freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) + if self.kwargs["log_sampling"]: + freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs) else: - freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) + freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs) for freq in freq_bands: - for p_fn in self.kwargs['periodic_fns']: - embed_fns.append(lambda x, p_fn=p_fn, - freq=freq: p_fn(x * freq)) + for p_fn in self.kwargs["periodic_fns"]: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) out_dim += d self.embed_fns = embed_fns @@ -36,35 +36,40 @@ class Embedder: def embed(self, inputs): return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + def get_embedder(multires, d_in=3): embed_kwargs = { - 'include_input': True, - 'input_dims': d_in, - 'max_freq_log2': multires-1, - 'num_freqs': multires, - 'log_sampling': True, - 'periodic_fns': [torch.sin, torch.cos], + "include_input": True, + "input_dims": d_in, + "max_freq_log2": multires - 1, + "num_freqs": multires, + "log_sampling": True, + "periodic_fns": [torch.sin, torch.cos], } embedder_obj = Embedder(**embed_kwargs) - def embed(x, eo=embedder_obj): return eo.embed(x) + + def embed(x, eo=embedder_obj): + return eo.embed(x) + return embed, embedder_obj.out_dim -def normalize_coordinate(p, plane='xz'): - ''' Normalize coordinate to [0, 1] for unit cube experiments + +def normalize_coordinate(p, plane="xz"): + """Normalize coordinate to [0, 1] for unit cube experiments. Args: p (tensor): point padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] plane (str): plane feature type, ['xz', 'xy', 'yz'] - ''' - if plane == 'xz': + """ + if plane == "xz": xy = p[:, :, [0, 2]] - elif plane =='xy': + elif plane == "xy": xy = p[:, :, [0, 1]] else: xy = p[:, :, [1, 2]] - + xy_new = xy # f there are outliers out of the range if xy_new.max() >= 1: @@ -75,40 +80,41 @@ def normalize_coordinate(p, plane='xz'): def normalize_3d_coordinate(p): - ''' Normalize coordinate to [0, 1] for unit cube experiments. - ''' + """Normalize coordinate to [0, 1] for unit cube experiments.""" if p.max() >= 1: p[p >= 1] = 1 - 10e-6 if p.min() < 0: p[p < 0] = 0.0 return p -def coordinate2index(x, reso, coord_type='2d'): - ''' Normalize coordinate to [0, 1] for unit cube experiments. 
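# A quick sanity check, assuming the get_embedder above: with include_input=True and
# two periodic functions (sin, cos), the output dimension for 3-D inputs is
# 3 + 3 * 2 * multires.
embed, out_dim = get_embedder(4, d_in=3)
assert out_dim == 27
assert embed(torch.zeros(10, 3)).shape == (10, 27)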
- Corresponds to our 3D model + +def coordinate2index(x, reso, coord_type="2d"): + """Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model. Args: x (tensor): coordinate reso (int): defined resolution coord_type (str): coordinate type - ''' + """ x = (x * reso).long() - if coord_type == '2d': # plane + if coord_type == "2d": # plane index = x[:, :, 0] + reso * x[:, :, 1] - elif coord_type == '3d': # grid + elif coord_type == "3d": # grid index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2]) index = index[:, None, :] return index -class map2local(object): - ''' Add new keys to the given input +class map2local: + """Add new keys to the given input. Args: s (float): the defined voxel size pos_encoding (str): method for the positional encoding, linear|sin_cos - ''' - def __init__(self, s, pos_encoding='linear'): + """ + + def __init__(self, s, pos_encoding="linear"): super().__init__() self.s = s # self.pe = positional_encoding(basis_function=pos_encoding, local=True) @@ -121,15 +127,16 @@ class map2local(object): # p = self.pe(p) return p + # Resnet Blocks class ResnetBlockFC(nn.Module): - ''' Fully connected ResNet Block class. + """Fully connected ResNet Block class. Args: size_in (int): input dimension size_out (int): output dimension size_h (int): hidden dimension - ''' + """ def __init__(self, size_in, size_out=None, size_h=None, siren=False): super().__init__() @@ -164,4 +171,4 @@ class ResnetBlockFC(nn.Module): else: x_s = x - return x_s + dx \ No newline at end of file + return x_s + dx diff --git a/src/optimization.py b/src/optimization.py index 12a0c1d..a156b65 100644 --- a/src/optimization.py +++ b/src/optimization.py @@ -1,112 +1,110 @@ -import time, os +import os + import numpy as np +import open3d as o3d import torch -from torch.nn import functional as F import trimesh +from pytorch3d.loss import chamfer_distance +from torchvision.io import write_video +from torchvision.utils import save_image from src.dpsr import DPSR from src.model import PSR2Mesh -from src.utils import grid_interp, verts_on_largest_mesh,\ - export_pointcloud, mc_from_psr, GaussianSmoothing -from src.visualize import visualize_points_mesh, visualize_psr_grid, \ - visualize_mesh_phong, render_rgb -from torchvision.utils import save_image -from torchvision.io import write_video -from pytorch3d.loss import chamfer_distance -import open3d as o3d +from src.utils import export_pointcloud, mc_from_psr, verts_on_largest_mesh +from src.visualize import ( + render_rgb, + visualize_mesh_phong, + visualize_points_mesh, + visualize_psr_grid, +) -class Trainer(object): - ''' - Args: - cfg : config file - optimizer : pytorch optimizer object - device : pytorch device - ''' + +class Trainer: + """Args: + cfg : config file + optimizer : pytorch optimizer object + device : pytorch device. 
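# A worked example, assuming the coordinate2index above: normalized 2-D plane
# coordinates are quantized to a reso x reso grid and flattened, so a point at
# (0.5, 0.25) with reso=64 falls into cell 32 + 64 * 16 = 1056.
xy = torch.tensor([[[0.5, 0.25]]])  # (batch, n_points, 2), already inside [0, 1)
index = coordinate2index(xy, reso=64, coord_type="2d")
assert index.item() == 1056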
+ """ def __init__(self, cfg, optimizer, device=None): self.optimizer = optimizer self.device = device self.cfg = cfg self.psr2mesh = PSR2Mesh.apply - self.data_type = cfg['data']['data_type'] + self.data_type = cfg["data"]["data_type"] # initialize DPSR - self.dpsr = DPSR(res=(cfg['model']['grid_res'], - cfg['model']['grid_res'], - cfg['model']['grid_res']), - sig=cfg['model']['psr_sigma']) - if torch.cuda.device_count() > 1: - self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR + self.dpsr = DPSR( + res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]), + sig=cfg["model"]["psr_sigma"], + ) + if torch.cuda.device_count() > 1: + self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR self.dpsr = self.dpsr.to(device) def train_step(self, data, inputs, model, it): - ''' Performs a training step. + """Performs a training step. Args: data (dict) : data dictionary inputs (torch.tensor) : input point clouds model (nn.Module or None): a neural network or None it (int) : the number of iterations - ''' - + """ self.optimizer.zero_grad() loss, loss_each = self.compute_loss(inputs, data, model, it) loss.backward() self.optimizer.step() - + return loss.item(), loss_each def compute_loss(self, inputs, data, model, it=0): - ''' Compute the loss. + """Compute the loss. + Args: data (dict) : data dictionary inputs (torch.tensor) : input point clouds model (nn.Module or None): a neural network or None - it (int) : the number of iterations - ''' + it (int) : the number of iterations. + """ + res = self.cfg["model"]["grid_res"] - device = self.device - res = self.cfg['model']['grid_res'] - # source oriented point clouds to PSR grid psr_grid, points, normals = self.pcl2psr(inputs) - + # build mesh v, f, n = self.psr2mesh(psr_grid) - - # the output is in the range of [0, 1), we make it to the real range [0, 1]. - # This is a hack for our DPSR solver - v = v * res / (res-1) - points = points * 2. - 1. - v = v * 2. - 1. # within the range of (-1, 1) + # the output is in the range of [0, 1), we make it to the real range [0, 1]. + # This is a hack for our DPSR solver + v = v * res / (res - 1) + + points = points * 2.0 - 1.0 + v = v * 2.0 - 1.0 # within the range of (-1, 1) loss = 0 loss_each = {} # compute loss - if self.data_type == 'point': - if self.cfg['train']['w_chamfer'] > 0: - loss_ = self.cfg['train']['w_chamfer'] * \ - self.compute_3d_loss(v, data) - loss_each['chamfer'] = loss_ + if self.data_type == "point": + if self.cfg["train"]["w_chamfer"] > 0: + loss_ = self.cfg["train"]["w_chamfer"] * self.compute_3d_loss(v, data) + loss_each["chamfer"] = loss_ loss += loss_ - elif self.data_type == 'img': + elif self.data_type == "img": loss, loss_each = self.compute_2d_loss(inputs, data, model) - + return loss, loss_each - def pcl2psr(self, inputs): - ''' Convert an oriented point cloud to PSR indicator grid + """Convert an oriented point cloud to PSR indicator grid Args: - inputs (torch.tensor): input oriented point clouds - ''' - - points, normals = inputs[...,:3], inputs[...,3:] - if self.cfg['model']['apply_sigmoid']: + inputs (torch.tensor): input oriented point clouds. 
+ """ + points, normals = inputs[..., :3], inputs[..., 3:] + if self.cfg["model"]["apply_sigmoid"]: points = torch.sigmoid(points) - if self.cfg['model']['normal_normalize']: + if self.cfg["model"]["normal_normalize"]: normals = normals / normals.norm(dim=-1, keepdim=True) # DPSR to get grid @@ -116,53 +114,50 @@ class Trainer(object): return psr_grid, points, normals def compute_3d_loss(self, v, data): - ''' Compute the loss for point clouds. + """Compute the loss for point clouds. + Args: v (torch.tensor) : mesh vertices - data (dict) : data dictionary - ''' - - pts_gt = data.get('target_points') - idx = np.random.randint(pts_gt.shape[1], size=self.cfg['train']['n_sup_point']) - if self.cfg['train']['subsample_vertex']: - #chamfer distance only on random sampled vertices - idx = np.random.randint(v.shape[1], size=self.cfg['train']['n_sup_point']) + data (dict) : data dictionary. + """ + pts_gt = data.get("target_points") + idx = np.random.randint(pts_gt.shape[1], size=self.cfg["train"]["n_sup_point"]) + if self.cfg["train"]["subsample_vertex"]: + # chamfer distance only on random sampled vertices + idx = np.random.randint(v.shape[1], size=self.cfg["train"]["n_sup_point"]) loss, _ = chamfer_distance(v[:, idx], pts_gt) else: loss, _ = chamfer_distance(v, pts_gt) return loss - + def compute_2d_loss(self, inputs, data, model): - ''' Compute the 2D losses. + """Compute the 2D losses. + Args: inputs (torch.tensor) : input source point clouds data (dict) : data dictionary - model (nn.Module or None): neural network or None - ''' - - losses = {"color": - {"weight": self.cfg['train']['l_weight']['rgb'], - "values": [] - }, - "silhouette": - {"weight": self.cfg['train']['l_weight']['mask'], - "values": []}, - } + model (nn.Module or None): neural network or None. + """ + losses = { + "color": { + "weight": self.cfg["train"]["l_weight"]["rgb"], + "values": [], + }, + "silhouette": {"weight": self.cfg["train"]["l_weight"]["mask"], "values": []}, + } loss_all = {k: torch.tensor(0.0, device=self.device) for k in losses} - + # forward pass out = model(inputs, data) - if out['rgb'] is not None: - rgb_gt = out['rgb_gt'].reshape(self.cfg['data']['n_views_per_iter'], - -1, 3)[out['vis_mask']] - loss_all["color"] += torch.nn.L1Loss(reduction='sum')(rgb_gt, - out['rgb']) / out['rgb'].shape[0] + if out["rgb"] is not None: + rgb_gt = out["rgb_gt"].reshape(self.cfg["data"]["n_views_per_iter"], -1, 3)[out["vis_mask"]] + loss_all["color"] += torch.nn.L1Loss(reduction="sum")(rgb_gt, out["rgb"]) / out["rgb"].shape[0] + + if out["mask"] is not None: + loss_all["silhouette"] += ((out["mask"] - out["mask_gt"]) ** 2).mean() - if out['mask'] is not None: - loss_all["silhouette"] += ((out['mask'] - out['mask_gt']) ** 2).mean() - # weighted sum of the losses loss = torch.tensor(0.0, device=self.device) for k, l in loss_all.items(): @@ -172,154 +167,146 @@ class Trainer(object): return loss, loss_all def point_resampling(self, inputs): - ''' Resample points + """Resample points Args: - inputs (torch.tensor): oriented point clouds - ''' - + inputs (torch.tensor): oriented point clouds. + """ psr_grid, points, normals = self.pcl2psr(inputs) - - # shortcuts - n_grow = self.cfg['train']['n_grow_points'] - # [hack] for points resampled from the mesh from marching cubes, + # shortcuts + n_grow = self.cfg["train"]["n_grow_points"] + + # [hack] for points resampled from the mesh from marching cubes, # we need to divide by s instead of (s-1), and the scale is correct. 
verts, faces, _ = mc_from_psr(psr_grid, real_scale=False, zero_level=0) # find the largest component pts_mesh, faces_mesh = verts_on_largest_mesh(verts, faces) - + # sample vertices only from the largest component, not from fragments mesh = trimesh.Trimesh(vertices=pts_mesh, faces=faces_mesh) - pi, face_idx = mesh.sample(n_grow+points.shape[1], return_index=True) - normals_i = mesh.face_normals[face_idx].astype('float32') - pts_mesh = torch.tensor(pi.astype('float32')).to(self.device)[None] + pi, face_idx = mesh.sample(n_grow + points.shape[1], return_index=True) + normals_i = mesh.face_normals[face_idx].astype("float32") + pts_mesh = torch.tensor(pi.astype("float32")).to(self.device)[None] n_mesh = torch.tensor(normals_i).to(self.device)[None] points, normals = pts_mesh, n_mesh - print('{} total points are resampled'.format(points.shape[1])) - + print(f"{points.shape[1]} total points are resampled") + # update inputs - points = torch.log(points / (1 - points)) # inverse sigmoid + points = torch.log(points / (1 - points)) # inverse sigmoid inputs = torch.cat([points, normals], dim=-1) - inputs.requires_grad = True + inputs.requires_grad = True return inputs def visualize(self, data, inputs, renderer, epoch, o3d_vis=None): - ''' Visualization. + """Visualization. + Args: data (dict) : data dictionary inputs (torch.tensor) : source point clouds renderer (nn.Module or None): a neural network or None epoch (int) : the number of iterations - o3d_vis (o3d.Visualizer) : open3d visualizer - ''' - - data_type = self.cfg['data']['data_type'] - it = '{:04d}'.format(int(epoch/self.cfg['train']['visualize_every'])) + o3d_vis (o3d.Visualizer) : open3d visualizer. + """ + data_type = self.cfg["data"]["data_type"] + it = "{:04d}".format(int(epoch / self.cfg["train"]["visualize_every"])) - - if (self.cfg['train']['exp_mesh']) \ - | (self.cfg['train']['exp_pcl']) \ - | (self.cfg['train']['o3d_show']): + if (self.cfg["train"]["exp_mesh"]) | (self.cfg["train"]["exp_pcl"]) | (self.cfg["train"]["o3d_show"]): psr_grid, points, normals = self.pcl2psr(inputs) with torch.no_grad(): - v, f, n = mc_from_psr(psr_grid, pytorchify=True, - zero_level=self.cfg['data']['zero_level'], real_scale=True) + v, f, n = mc_from_psr( + psr_grid, pytorchify=True, zero_level=self.cfg["data"]["zero_level"], real_scale=True, + ) v, f, n = v[None], f[None], n[None] - - v = v * 2. - 1. # change to the range of [-1, 1] + + v = v * 2.0 - 1.0 # change to the range of [-1, 1] color_v = None - if data_type == 'img': - if self.cfg['train']['vis_vert_color'] & \ - (self.cfg['train']['l_weight']['rgb'] != 0.): - color_v = renderer['color'](v, n).squeeze().detach().cpu().numpy() - color_v[color_v<0], color_v[color_v>1] = 0., 1. 
+ if data_type == "img": + if self.cfg["train"]["vis_vert_color"] & (self.cfg["train"]["l_weight"]["rgb"] != 0.0): + color_v = renderer["color"](v, n).squeeze().detach().cpu().numpy() + color_v[color_v < 0], color_v[color_v > 1] = 0.0, 1.0 vv = v.detach().squeeze().cpu().numpy() ff = f.detach().squeeze().cpu().numpy() points = points * 2 - 1 - visualize_points_mesh(o3d_vis, points, normals, - vv, ff, self.cfg, it, epoch, color_v=color_v) + visualize_points_mesh(o3d_vis, points, normals, vv, ff, self.cfg, it, epoch, color_v=color_v) else: v, f, n = inputs - - if (data_type == 'img') & (self.cfg['train']['vis_rendering']): + if (data_type == "img") & (self.cfg["train"]["vis_rendering"]): pred_imgs = [] - pred_masks = [] - n_views = len(data['poses']) + len(data["poses"]) # idx_list = trange(n_views) idx_list = [13, 24, 27, 48] - #! + #! model = renderer.eval() for idx in idx_list: - pose = data['poses'][idx] - rgb = data['rgbs'][idx] - mask_gt = data['masks'][idx] - img_size = rgb.shape[0] if rgb.shape[0]== rgb.shape[1] else (rgb.shape[0], rgb.shape[1]) + pose = data["poses"][idx] + rgb = data["rgbs"][idx] + data["masks"][idx] + img_size = rgb.shape[0] if rgb.shape[0] == rgb.shape[1] else (rgb.shape[0], rgb.shape[1]) ray = None - if 'rays' in data.keys(): - ray = data['rays'][idx] - if self.cfg['train']['l_weight']['rgb'] != 0.: + if "rays" in data.keys(): + ray = data["rays"][idx] + if self.cfg["train"]["l_weight"]["rgb"] != 0.0: fea_grid = None if model.unet3d is not None: with torch.no_grad(): fea_grid = model.unet3d(psr_grid).permute(0, 2, 3, 4, 1) if model.encoder is not None: - pp = torch.cat([(points+1)/2, normals], dim=-1) - fea_grid = model.encoder(pp, - normalize=False).permute(0, 2, 3, 4, 1) + pp = torch.cat([(points + 1) / 2, normals], dim=-1) + fea_grid = model.encoder(pp, normalize=False).permute(0, 2, 3, 4, 1) - pred, visible_mask = render_rgb(v, f, n, pose, - model.rendering_network.eval(), - img_size, ray=ray, fea_grid=fea_grid) - img_pred = torch.zeros([rgb.shape[0]*rgb.shape[1], 3]) + pred, visible_mask = render_rgb( + v, f, n, pose, model.rendering_network.eval(), img_size, ray=ray, fea_grid=fea_grid, + ) + img_pred = torch.zeros([rgb.shape[0] * rgb.shape[1], 3]) img_pred[visible_mask] = pred.detach().cpu() img_pred = img_pred.reshape(rgb.shape[0], rgb.shape[1], 3) - img_pred[img_pred<0], img_pred[img_pred>1] = 0., 1. - filename=os.path.join(self.cfg['train']['dir_rendering'], - 'rendering_{}_{:d}.png'.format(it, idx)) + img_pred[img_pred < 0], img_pred[img_pred > 1] = 0.0, 1.0 + filename = os.path.join(self.cfg["train"]["dir_rendering"], f"rendering_{it}_{idx:d}.png") save_image(img_pred.permute(2, 0, 1), filename) pred_imgs.append(img_pred) #! 
Mesh rendering using Phong shading model - filename=os.path.join(self.cfg['train']['dir_rendering'], - 'mesh_{}_{:d}.png'.format(it, idx)) + filename = os.path.join(self.cfg["train"]["dir_rendering"], f"mesh_{it}_{idx:d}.png") visualize_mesh_phong(v, f, n, pose, img_size, name=filename) if len(pred_imgs) >= 1: pred_imgs = torch.stack(pred_imgs, dim=0) - save_image(pred_imgs.permute(0, 3, 1, 2), - os.path.join(self.cfg['train']['dir_rendering'], - '{}.png'.format(it)), nrow=4) - if self.cfg['train']['save_video']: - write_video(os.path.join(self.cfg['train']['dir_rendering'], - '{}.mp4'.format(it)), - (pred_imgs*255.).type(torch.uint8), fps=24) + save_image( + pred_imgs.permute(0, 3, 1, 2), os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.png"), nrow=4, + ) + if self.cfg["train"]["save_video"]: + write_video( + os.path.join(self.cfg["train"]["dir_rendering"], f"{it}.mp4"), + (pred_imgs * 255.0).type(torch.uint8), + fps=24, + ) def save_mesh_pointclouds(self, inputs, epoch, center=None, scale=None): - ''' Save meshes and point clouds. + """Save meshes and point clouds. + Args: inputs (torch.tensor) : source point clouds epoch (int) : the number of iterations center (numpy.array) : center of the shape - scale (numpy.array) : scale of the shape - ''' + scale (numpy.array) : scale of the shape. + """ + exp_pcl = self.cfg["train"]["exp_pcl"] + exp_mesh = self.cfg["train"]["exp_mesh"] - exp_pcl = self.cfg['train']['exp_pcl'] - exp_mesh = self.cfg['train']['exp_mesh'] - psr_grid, points, normals = self.pcl2psr(inputs) - + if exp_pcl: - dir_pcl = self.cfg['train']['dir_pcl'] + dir_pcl = self.cfg["train"]["dir_pcl"] p = points.squeeze(0).detach().cpu().numpy() p = p * 2 - 1 n = normals.squeeze(0).detach().cpu().numpy() @@ -327,12 +314,11 @@ class Trainer(object): p *= scale if center is not None: p += center - export_pointcloud(os.path.join(dir_pcl, '{:04d}.ply'.format(epoch)), p, n) + export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}.ply"), p, n) if exp_mesh: - dir_mesh = self.cfg['train']['dir_mesh'] + dir_mesh = self.cfg["train"]["dir_mesh"] with torch.no_grad(): - v, f, _ = mc_from_psr(psr_grid, - zero_level=self.cfg['data']['zero_level'], real_scale=True) + v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"], real_scale=True) v = v * 2 - 1 if scale is not None: v *= scale @@ -341,9 +327,9 @@ class Trainer(object): mesh = o3d.geometry.TriangleMesh() mesh.vertices = o3d.utility.Vector3dVector(v) mesh.triangles = o3d.utility.Vector3iVector(f) - outdir_mesh = os.path.join(dir_mesh, '{:04d}.ply'.format(epoch)) + outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}.ply") o3d.io.write_triangle_mesh(outdir_mesh, mesh) - if self.cfg['train']['vis_psr']: - dir_psr_vis = self.cfg['train']['out_dir']+'/psr_vis_all' + if self.cfg["train"]["vis_psr"]: + dir_psr_vis = self.cfg["train"]["out_dir"] + "/psr_vis_all" visualize_psr_grid(psr_grid, out_dir=dir_psr_vis) diff --git a/src/training.py b/src/training.py index b38a1d4..6753657 100644 --- a/src/training.py +++ b/src/training.py @@ -1,75 +1,81 @@ import os +from collections import defaultdict + import numpy as np import torch +from pytorch3d.loss import chamfer_distance +from pytorch3d.ops.knn import knn_gather, knn_points from torch.nn import functional as F -from collections import defaultdict -import trimesh from tqdm import tqdm from src.dpsr import DPSR -from src.utils import grid_interp, export_pointcloud, export_mesh, \ - mc_from_psr, scale2onet, GaussianSmoothing -from pytorch3d.ops.knn import knn_gather, knn_points 
-from pytorch3d.loss import chamfer_distance -from pdb import set_trace as st +from src.utils import ( + GaussianSmoothing, + export_mesh, + export_pointcloud, + mc_from_psr, + scale2onet, +) + + +class Trainer: + """Args: + model (nn.Module): our defined model + optimizer (optimizer): pytorch optimizer object + device (device): pytorch device + input_type (str): input type + vis_dir (str): visualization directory. + """ -class Trainer(object): - ''' - Args: - model (nn.Module): our defined model - optimizer (optimizer): pytorch optimizer object - device (device): pytorch device - input_type (str): input type - vis_dir (str): visualization directory - ''' def __init__(self, cfg, optimizer, device=None): self.optimizer = optimizer self.device = device self.cfg = cfg - if self.cfg['train']['w_raw'] != 0: + if self.cfg["train"]["w_raw"] != 0: from src.model import PSR2Mesh + self.psr2mesh = PSR2Mesh.apply # initialize DPSR - self.dpsr = DPSR(res=(cfg['model']['grid_res'], - cfg['model']['grid_res'], - cfg['model']['grid_res']), - sig=cfg['model']['psr_sigma']) - if torch.cuda.device_count() > 1: - self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR + self.dpsr = DPSR( + res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]), + sig=cfg["model"]["psr_sigma"], + ) + if torch.cuda.device_count() > 1: + self.dpsr = torch.nn.DataParallel(self.dpsr) # parallell DPSR self.dpsr = self.dpsr.to(device) - if cfg['train']['gauss_weight']>0.: + if cfg["train"]["gauss_weight"] > 0.0: self.gauss_smooth = GaussianSmoothing(1, 7, 2).to(device) - + def train_step(self, inputs, data, model): - ''' Performs a training step. + """Performs a training step. Args: data (dict): data dictionary - ''' + """ self.optimizer.zero_grad() - p = data.get('inputs').to(self.device) - + p = data.get("inputs").to(self.device) + out = model(p) - + points, normals = out loss = 0 loss_each = {} - if self.cfg['train']['w_psr'] != 0: - psr_gt = data.get('gt_psr').to(self.device) - if self.cfg['model']['psr_tanh']: + if self.cfg["train"]["w_psr"] != 0: + psr_gt = data.get("gt_psr").to(self.device) + if self.cfg["model"]["psr_tanh"]: psr_gt = torch.tanh(psr_gt) - + psr_grid = self.dpsr(points, normals) - if self.cfg['model']['psr_tanh']: + if self.cfg["model"]["psr_tanh"]: psr_grid = torch.tanh(psr_grid) # apply a rescaling weight based on GT SDF values - if self.cfg['train']['gauss_weight']>0: - gauss_sigma = self.cfg['train']['gauss_weight'] - # set up the weighting for loss, higher weights + if self.cfg["train"]["gauss_weight"] > 0: + self.cfg["train"]["gauss_weight"] + # set up the weighting for loss, higher weights # for points near to the surface psr_gt_pad = torch.nn.ReplicationPad3d(1)(psr_gt.unsqueeze(1)).squeeze(1) delta_x = delta_y = delta_z = 1 @@ -82,97 +88,97 @@ class Trainer(object): psr_grad = torch.stack([grad_x, grad_y, grad_z], dim=-1) psr_grad_norm = psr_grad.norm(dim=-1)[:, None] w = torch.nn.ReplicationPad3d(3)(psr_grad_norm) - w = 2*self.gauss_smooth(w).squeeze(1) - loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(w*psr_grid, w*psr_gt) + w = 2 * self.gauss_smooth(w).squeeze(1) + loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(w * psr_grid, w * psr_gt) else: - loss_each['psr'] = self.cfg['train']['w_psr'] * F.mse_loss(psr_grid, psr_gt) + loss_each["psr"] = self.cfg["train"]["w_psr"] * F.mse_loss(psr_grid, psr_gt) - loss += loss_each['psr'] + loss += loss_each["psr"] # regularization on the input point positions via chamfer distance - if 
self.cfg['train']['w_reg_point'] != 0.: - points_gt = data.get('gt_points').to(self.device) + if self.cfg["train"]["w_reg_point"] != 0.0: + points_gt = data.get("gt_points").to(self.device) loss_reg, loss_norm = chamfer_distance(points, points_gt) - - loss_each['reg'] = self.cfg['train']['w_reg_point'] * loss_reg - loss += loss_each['reg'] - - if self.cfg['train']['w_normals'] != 0.: - points_gt = data.get('gt_points').to(self.device) - normals_gt = data.get('gt_points.normals').to(self.device) + + loss_each["reg"] = self.cfg["train"]["w_reg_point"] * loss_reg + loss += loss_each["reg"] + + if self.cfg["train"]["w_normals"] != 0.0: + points_gt = data.get("gt_points").to(self.device) + normals_gt = data.get("gt_points.normals").to(self.device) x_nn = knn_points(points, points_gt, K=1) x_normals_near = knn_gather(normals_gt, x_nn.idx)[..., 0, :] - + cham_norm_x = F.l1_loss(normals, x_normals_near) loss_norm = cham_norm_x - loss_each['normals'] = self.cfg['train']['w_normals'] * loss_norm - loss += loss_each['normals'] - - if self.cfg['train']['w_raw'] != 0: - res = self.cfg['model']['grid_res'] + loss_each["normals"] = self.cfg["train"]["w_normals"] * loss_norm + loss += loss_each["normals"] + + if self.cfg["train"]["w_raw"] != 0: + self.cfg["model"]["grid_res"] # DPSR to get grid psr_grid = self.dpsr(points, normals) - if self.cfg['model']['psr_tanh']: + if self.cfg["model"]["psr_tanh"]: psr_grid = torch.tanh(psr_grid) - + v, f, n = self.psr2mesh(psr_grid) - pts_gt = data.get('gt_points').to(self.device) - + pts_gt = data.get("gt_points").to(self.device) + loss, _ = chamfer_distance(v, pts_gt) loss.backward() self.optimizer.step() return loss.item(), loss_each - - def save(self, model, data, epoch, id): - - p = data.get('inputs').to(self.device) - exp_pcl = self.cfg['train']['exp_pcl'] - exp_mesh = self.cfg['train']['exp_mesh'] - exp_gt = self.cfg['generation']['exp_gt'] - exp_input = self.cfg['generation']['exp_input'] - + def save(self, model, data, epoch, id): + p = data.get("inputs").to(self.device) + + exp_pcl = self.cfg["train"]["exp_pcl"] + exp_mesh = self.cfg["train"]["exp_mesh"] + exp_gt = self.cfg["generation"]["exp_gt"] + exp_input = self.cfg["generation"]["exp_input"] + model.eval() with torch.no_grad(): points, normals = model(p) if exp_gt: - points_gt = data.get('gt_points').to(self.device) - normals_gt = data.get('gt_points.normals').to(self.device) + points_gt = data.get("gt_points").to(self.device) + normals_gt = data.get("gt_points.normals").to(self.device) if exp_pcl: - dir_pcl = self.cfg['train']['dir_pcl'] - export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}.ply'.format(epoch, id)), scale2onet(points), normals) + dir_pcl = self.cfg["train"]["dir_pcl"] + export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}.ply"), scale2onet(points), normals) if exp_gt: - export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(points_gt), normals_gt) + export_pointcloud( + os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(points_gt), normals_gt, + ) if exp_input: - export_pointcloud(os.path.join(dir_pcl, '{:04d}_{:01d}_input.ply'.format(epoch, id)), scale2onet(p)) + export_pointcloud(os.path.join(dir_pcl, f"{epoch:04d}_{id:01d}_input.ply"), scale2onet(p)) if exp_mesh: - dir_mesh = self.cfg['train']['dir_mesh'] + dir_mesh = self.cfg["train"]["dir_mesh"] psr_grid = self.dpsr(points, normals) # psr_grid = torch.tanh(psr_grid) with torch.no_grad(): - v, f, _ = mc_from_psr(psr_grid, - 
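The normal term in train_step pairs every predicted point with its nearest ground-truth point and compares their normals; a small self-contained sketch with random stand-in clouds (the real code feeds the network outputs and gt_points):

import torch
import torch.nn.functional as F
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops.knn import knn_gather, knn_points

points     = torch.rand(1, 2048, 3)                        # predicted positions
normals    = F.normalize(torch.randn(1, 2048, 3), dim=-1)  # predicted normals
points_gt  = torch.rand(1, 2048, 3)
normals_gt = F.normalize(torch.randn(1, 2048, 3), dim=-1)

# position regularizer: symmetric Chamfer distance between the two clouds
loss_reg, _ = chamfer_distance(points, points_gt)

# normal term: fetch the normal of each point's 1-NN in the GT cloud, then L1
nn = knn_points(points, points_gt, K=1)
normals_near = knn_gather(normals_gt, nn.idx)[..., 0, :]   # (1, 2048, 3)
loss_normals = F.l1_loss(normals, normals_near)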
zero_level=self.cfg['data']['zero_level']) - outdir_mesh = os.path.join(dir_mesh, '{:04d}_{:01d}.ply'.format(epoch, id)) + v, f, _ = mc_from_psr(psr_grid, zero_level=self.cfg["data"]["zero_level"]) + outdir_mesh = os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}.ply") export_mesh(outdir_mesh, scale2onet(v), f) if exp_gt: psr_gt = self.dpsr(points_gt, normals_gt) with torch.no_grad(): - v, f, _ = mc_from_psr(psr_gt, - zero_level=self.cfg['data']['zero_level']) - export_mesh(os.path.join(dir_mesh, '{:04d}_{:01d}_oracle.ply'.format(epoch, id)), scale2onet(v), f) - + v, f, _ = mc_from_psr(psr_gt, zero_level=self.cfg["data"]["zero_level"]) + export_mesh(os.path.join(dir_mesh, f"{epoch:04d}_{id:01d}_oracle.ply"), scale2onet(v), f) + def evaluate(self, val_loader, model): - ''' Performs an evaluation. + """Performs an evaluation. + Args: - val_loader (dataloader): pytorch dataloader - ''' + val_loader (dataloader): pytorch dataloader. + """ eval_list = defaultdict(list) for data in tqdm(val_loader): @@ -183,25 +189,26 @@ class Trainer(object): eval_dict = {k: np.mean(v) for k, v in eval_list.items()} return eval_dict - + def eval_step(self, data, model): - ''' Performs an evaluation step. + """Performs an evaluation step. + Args: - data (dict): data dictionary - ''' + data (dict): data dictionary. + """ model.eval() eval_dict = {} - p = data.get('inputs').to(self.device) - psr_gt = data.get('gt_psr').to(self.device) - + p = data.get("inputs").to(self.device) + psr_gt = data.get("gt_psr").to(self.device) + with torch.no_grad(): # forward pass points, normals = model(p) # DPSR to get predicted psr grid psr_grid = self.dpsr(points, normals) - eval_dict['psr_l1'] = F.l1_loss(psr_grid, psr_gt).item() - eval_dict['psr_l2'] = F.mse_loss(psr_grid, psr_gt).item() + eval_dict["psr_l1"] = F.l1_loss(psr_grid, psr_gt).item() + eval_dict["psr_l2"] = F.mse_loss(psr_grid, psr_gt).item() - return eval_dict \ No newline at end of file + return eval_dict diff --git a/src/utils.py b/src/utils.py index b9ce35b..222e251 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,53 +1,53 @@ -import torch -import io, os, logging, urllib -import yaml -import trimesh -import imageio -import numbers +import logging import math -import numpy as np +import numbers +import os +import urllib from collections import OrderedDict + +import numpy as np +import open3d as o3d +import torch +import trimesh +import yaml +from igl import adjacency_matrix, connected_components from plyfile import PlyData +from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes +from pytorch3d.structures import Meshes +from skimage import measure from torch import nn from torch.nn import functional as F from torch.utils import model_zoo -from skimage import measure, img_as_float32 -from pytorch3d.structures import Meshes -from pytorch3d.renderer import PerspectiveCameras, rasterize_meshes -from igl import adjacency_matrix, connected_components -import open3d as o3d ################################################## # Below are functions for DPSR + def fftfreqs(res, dtype=torch.float32, exact=True): - """ - Helper function to return frequency tensors + """Helper function to return frequency tensors :param res: n_dims int tuple of number of frequency modes :return: """ - n_dims = len(res) freqs = [] for dim in range(n_dims - 1): r_ = res[dim] - freq = np.fft.fftfreq(r_, d=1/r_) + freq = np.fft.fftfreq(r_, d=1 / r_) freqs.append(torch.tensor(freq, dtype=dtype)) r_ = res[-1] if exact: - freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype)) + 
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_), dtype=dtype)) else: - freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype)) + freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1 / r_)[:-1], dtype=dtype)) omega = torch.meshgrid(freqs) omega = list(omega) omega = torch.stack(omega, dim=-1) return omega -def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag) - """ - multiply tensor x by i ** deg - """ + +def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag) + """Multiply tensor x by i ** deg.""" deg %= 4 if deg == 0: res = x @@ -61,17 +61,18 @@ def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag) res[..., 1] = -res[..., 1] return res + def spec_gaussian_filter(res, sig): - omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d] - dis = torch.sqrt(torch.sum(omega ** 2, dim=-1)) - filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1) + omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d] + dis = torch.sqrt(torch.sum(omega**2, dim=-1)) + filter_ = torch.exp(-0.5 * ((sig * 2 * dis / res[0]) ** 2)).unsqueeze(-1).unsqueeze(-1) filter_.requires_grad = False return filter_ + def grid_interp(grid, pts, batched=True): - """ - :param grid: tensor of shape (batch, *size, in_features) + """:param grid: tensor of shape (batch, *size, in_features) :param pts: tensor of shape (batch, num_points, dim) within range (0, 1) :return values at query points """ @@ -82,69 +83,71 @@ def grid_interp(grid, pts, batched=True): bs = grid.shape[0] size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype) cubesize = 1.0 / size - + ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim) - ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around - ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim) - tmp = torch.tensor([0,1],dtype=torch.long) + ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around + ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim) + tmp = torch.tensor([0, 1], dtype=torch.long) com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim) - dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim) - ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points) - ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) - ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim) + dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim) + ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points) + ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) + ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim) # latent code on neighbor nodes if dim == 2: - lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features) + lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features) else: - lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features) + lat = grid.clone()[ + ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2], + ] # (batch, num_points, 2**dim, in_features) # weights of neighboring nodes - xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim) + xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim) xyz1 = 
(ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim) - xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim) - pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) - pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) + xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim) + xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) + pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) pos_ = pos_.type(pts.dtype) - dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim) - weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim) + dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim) + weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim) query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features) if not batched: query_values = query_values.squeeze(0) - + return query_values + def scatter_to_grid(inds, vals, size): - """ - Scatter update values into empty tensor of size size. + """Scatter update values into empty tensor of size size. :param inds: (#values, dims) :param vals: (#values) - :param size: tuple for size. len(size)=dims + :param size: tuple for size. len(size)=dims. """ dims = inds.shape[1] - assert(inds.shape[0] == vals.shape[0]) - assert(len(size) == dims) + assert inds.shape[0] == vals.shape[0] + assert len(size) == dims dev = vals.device # result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten # # flatten inds result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten # flatten inds - fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1] + fac = [np.prod(size[i + 1 :]) for i in range(len(size) - 1)] + [1] fac = torch.tensor(fac, device=dev).type(inds.dtype) - inds_fold = torch.sum(inds*fac, dim=-1) # [#values,] + inds_fold = torch.sum(inds * fac, dim=-1) # [#values,] result.scatter_add_(0, inds_fold, vals) result = result.view(*size) return result + def point_rasterize(pts, vals, size): - """ - :param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1) + """:param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1) :param vals: point values, tensor of shape (batch, num_points, features) :param size: len(size)=dim tuple for grid size :return rasterized values (batch, features, res0, res1, res2) """ dim = pts.shape[-1] - assert(pts.shape[:2] == vals.shape[:2]) - assert(pts.shape[2] == dim) + assert pts.shape[:2] == vals.shape[:2] + assert pts.shape[2] == dim size_list = list(size) size = torch.tensor(size).to(pts.device).float() cubesize = 1.0 / size @@ -152,55 +155,58 @@ def point_rasterize(pts, vals, size): nf = vals.shape[-1] npts = pts.shape[1] dev = pts.device - + ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim) - ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around - ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim) - tmp = torch.tensor([0,1],dtype=torch.long) + ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around + ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim) + tmp = torch.tensor([0, 1], dtype=torch.long) com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim) - dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # 
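grid_interp and point_rasterize share the same trilinear weights, so the second acts as the scatter counterpart of the first's gather. A shape-level usage sketch, assuming both functions are importable from src.utils as defined here:

import torch
from src.utils import grid_interp, point_rasterize

res, feat = 32, 4
grid = torch.rand(1, res, res, res, feat)      # (batch, *size, in_features)
pts  = torch.rand(1, 100, 3)                   # query points in [0, 1)

# trilinear read: grid values at continuous point locations
vals = grid_interp(grid, pts)                  # -> (1, 100, 4)

# trilinear write: scatter per-point features back onto an empty grid
raster = point_rasterize(pts, vals, (res, res, res))   # -> (1, 4, 32, 32, 32)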
(2**dim, dim) - ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points) - ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) + dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim) + ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points) + ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) # ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim) - ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim) - + ind_b = ( + torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) + ) # (batch, num_points, 2**dim) + # weights of neighboring nodes - xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim) + xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim) xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim) - xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim) - pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) - pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim) + xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim) + xyz01[com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) + pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim) pos_ = pos_.type(pts.dtype) - dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim) - weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim) - - ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1) - ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim) + dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim) + weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim) + + ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1) + ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim) ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1) # ind_f = torch.arange(nf).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1) - + ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1) ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev) ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1) inds = torch.cat([ind_b, ind_f, ind_n], dim=-1) # (batch, num_points, 2**dim, nf, 1+1+dim) - - # weighted values - vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf) - - inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf) - vals = vals.reshape(-1) # (bs*npts*2**dim*nf) - tensor_size = [bs, nf] + size_list - raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list) - - return raster # [batch, nf, res, res, res] + # weighted values + vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf) + + inds = inds.view(-1, dim + 2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf) + vals = vals.reshape(-1) # (bs*npts*2**dim*nf) + [bs, nf, *size_list] + raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf, *size_list]) + + return raster # [batch, nf, res, res, res] ################################################## # Below are the utilization functions in general -class AverageMeter(object): - """Computes and stores the average and current value""" + +class AverageMeter: + 
"""Computes and stores the average and current value.""" + def __init__(self): self.reset() @@ -226,49 +232,49 @@ class AverageMeter(object): def avgcavg(self): return self.avg.sum().item() / (self.count != 0).sum().item() + def load_model_manual(state_dict, model): new_state_dict = OrderedDict() is_model_parallel = isinstance(model, torch.nn.DataParallel) for k, v in state_dict.items(): - if k.startswith('module.') != is_model_parallel: - if k.startswith('module.'): + if k.startswith("module.") != is_model_parallel: + if k.startswith("module."): # remove module k = k[7:] else: # add module - k = 'module.' + k + k = "module." + k - new_state_dict[k]=v + new_state_dict[k] = v model.load_state_dict(new_state_dict) + def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0): - ''' - Run marching cubes from PSR grid - ''' + """Run marching cubes from PSR grid.""" batch_size = psr_grid.shape[0] - s = psr_grid.shape[-1] # size of psr_grid + s = psr_grid.shape[-1] # size of psr_grid psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy() - - if batch_size>1: + + if batch_size > 1: verts, faces, normals = [], [], [] for i in range(batch_size): verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0) verts.append(verts_cur) faces.append(faces_cur) normals.append(normals_cur) - verts = np.stack(verts, axis = 0) - faces = np.stack(faces, axis = 0) - normals = np.stack(normals, axis = 0) + verts = np.stack(verts, axis=0) + faces = np.stack(faces, axis=0) + normals = np.stack(normals, axis=0) else: try: verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level) except: verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy) if real_scale: - verts = verts / (s-1) # scale to range [0, 1] + verts = verts / (s - 1) # scale to range [0, 1] else: - verts = verts / s # scale to range [0, 1) + verts = verts / s # scale to range [0, 1) if pytorchify: device = psr_grid.device @@ -277,7 +283,8 @@ def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0): normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device) return verts, faces, normals - + + def calc_inters_points(verts, faces, pose, img_size, mask_gt=None): verts = verts.squeeze() faces = faces.squeeze() @@ -288,42 +295,39 @@ def calc_inters_points(verts, faces, pose, img_size, mask_gt=None): # find 3D points intesected on the mesh if True: w_masked = w[mask] - f_p = faces[pix_to_face[mask]].long() # cooresponding faces for each pixel + f_p = faces[pix_to_face[mask]].long() # cooresponding faces for each pixel # corresponding vertices for p_closest v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]] - + # calculate the intersection point of each pixel and the mesh - p_inters = w_masked[..., 0, None] * v_a + \ - w_masked[..., 1, None] * v_b + \ - w_masked[..., 2, None] * v_c + p_inters = w_masked[..., 0, None] * v_a + w_masked[..., 1, None] * v_b + w_masked[..., 2, None] * v_c else: # backproject ndc to world coordinates using z-buffer W, H = img_size[1], img_size[0] xy = uv.to(mask.device)[mask] - x_ndc = 1 - (2*xy[:, 0]) / (W - 1) - y_ndc = 1 - (2*xy[:, 1]) / (H - 1) + x_ndc = 1 - (2 * xy[:, 0]) / (W - 1) + y_ndc = 1 - (2 * xy[:, 1]) / (H - 1) z = zbuf.squeeze().reshape(H * W)[mask] xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1) - + p_inters = pose.unproject_points(xy_depth, world_coordinates=True) - + # if there are outlier points, we should remove it - if (p_inters.max()>1) | (p_inters.min()<-1): - 
mask_bound = (p_inters>=-1) & (p_inters<=1) - mask_bound = (mask_bound.sum(dim=-1)==3) - mask[mask==True] = mask_bound + if (p_inters.max() > 1) | (p_inters.min() < -1): + mask_bound = (p_inters >= -1) & (p_inters <= 1) + mask_bound = mask_bound.sum(dim=-1) == 3 + mask[mask is True] = mask_bound p_inters = p_inters[mask_bound] - print('!!!!!find outlier!') + print("!!!!!find outlier!") return p_inters, mask, f_p, w_masked + def mesh_rasterization(verts, faces, pose, img_size): - ''' - Use PyTorch3D to rasterize the mesh given a camera - ''' - transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system + """Use PyTorch3D to rasterize the mesh given a camera.""" + transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system if isinstance(pose, PerspectiveCameras): - transformed_v[..., 2] = 1/transformed_v[..., 2] + transformed_v[..., 2] = 1 / transformed_v[..., 2] # find p_closest on mesh of each pixel via rasterization transformed_mesh = Meshes(verts=[transformed_v], faces=[faces]) pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( @@ -331,9 +335,9 @@ def mesh_rasterization(verts, faces, pose, img_size): image_size=img_size, blur_radius=0, faces_per_pixel=1, - perspective_correct=False + perspective_correct=False, ) - pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso) + pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso) mask = pix_to_face.clone() != -1 mask = mask.squeeze() pix_to_face = pix_to_face.squeeze() @@ -341,11 +345,11 @@ def mesh_rasterization(verts, faces, pose, img_size): return pix_to_face, w, mask + def verts_on_largest_mesh(verts, faces): - ''' - verts: Numpy array or Torch.Tensor (N, 3) - faces: Numpy array (N, 3) - ''' + """verts: Numpy array or Torch.Tensor (N, 3) + faces: Numpy array (N, 3). + """ if torch.is_tensor(faces): verts = verts.squeeze().detach().cpu().numpy() faces = faces.squeeze().int().detach().cpu().numpy() @@ -355,8 +359,8 @@ def verts_on_largest_mesh(verts, faces): if num == 0: v_large, f_large = verts, faces else: - max_idx = conn_size.argmax() # find the index of the largest component - v_large = verts[conn_idx==max_idx] # keep points on the largest component + max_idx = conn_size.argmax() # find the index of the largest component + v_large = verts[conn_idx == max_idx] # keep points on the largest component if True: mesh_largest = trimesh.Trimesh(verts, faces) @@ -366,36 +370,41 @@ def verts_on_largest_mesh(verts, faces): v_large = v_large.astype(np.float32) return v_large, f_large + def load_pointcloud(in_file): plydata = PlyData.read(in_file) - vertices = np.stack([ - plydata['vertex']['x'], - plydata['vertex']['y'], - plydata['vertex']['z'] - ], axis=1) + vertices = np.stack( + [ + plydata["vertex"]["x"], + plydata["vertex"]["y"], + plydata["vertex"]["z"], + ], + axis=1, + ) return vertices + # General config def load_config(path, default_path=None): - ''' Loads config file. + """Loads config file. 
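calc_inters_points narrows the pixel hit mask by an in-bounds verdict computed only for the pixels that hit the mesh. A tiny hypothetical example of that boolean-mask refinement pattern; note that an identity test such as `mask is True` is always False for a tensor, so element-wise boolean indexing is what actually does the narrowing:

import torch

keep = torch.tensor([False, True, True, True, False])   # 3 of 5 pixels hit the mesh
in_bounds = torch.tensor([True, False, True])           # verdict for just those 3 hits

keep[keep.clone()] = in_bounds                           # write the verdict back
print(keep)   # tensor([False,  True, False,  True, False])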
- Args: + Args: path (str): path to config file default_path (bool): whether to use default path - ''' + """ # Load configuration from file itself - with open(path, 'r') as f: + with open(path) as f: cfg_special = yaml.load(f, Loader=yaml.Loader) # Check if we should inherit from a config - inherit_from = cfg_special.get('inherit_from') + inherit_from = cfg_special.get("inherit_from") # If yes, load this config first as default # If no, use the default_path if inherit_from is not None: cfg = load_config(inherit_from, default_path) elif default_path is not None: - with open(default_path, 'r') as f: + with open(default_path) as f: cfg = yaml.load(f, Loader=yaml.Loader) else: cfg = dict() @@ -405,65 +414,67 @@ def load_config(path, default_path=None): return cfg + def update_config(config, unknown): # update config given args - for idx,arg in enumerate(unknown): + for idx, arg in enumerate(unknown): if arg.startswith("--"): - keys = arg.replace("--","").split(':') - assert(len(keys)==2) + keys = arg.replace("--", "").split(":") + assert len(keys) == 2 k1, k2 = keys argtype = type(config[k1][k2]) if argtype == bool: - v = unknown[idx+1].lower() == 'true' + v = unknown[idx + 1].lower() == "true" else: if config[k1][k2] is not None: - v = type(config[k1][k2])(unknown[idx+1]) + v = type(config[k1][k2])(unknown[idx + 1]) else: - v = unknown[idx+1] - print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}') + v = unknown[idx + 1] + print(f"Changing {k1}:{k2} ---- {config[k1][k2]} to {v}") config[k1][k2] = v return config + def initialize_logger(cfg): - out_dir = cfg['train']['out_dir'] + out_dir = cfg["train"]["out_dir"] if not out_dir: os.makedirs(out_dir) - - cfg['train']['dir_model'] = os.path.join(out_dir, 'model') - os.makedirs(cfg['train']['dir_model'], exist_ok=True) - if cfg['train']['exp_mesh']: - cfg['train']['dir_mesh'] = os.path.join(out_dir, 'vis/mesh') - os.makedirs(cfg['train']['dir_mesh'], exist_ok=True) - if cfg['train']['exp_pcl']: - cfg['train']['dir_pcl'] = os.path.join(out_dir, 'vis/pointcloud') - os.makedirs(cfg['train']['dir_pcl'], exist_ok=True) - if cfg['train']['vis_rendering']: - cfg['train']['dir_rendering'] = os.path.join(out_dir, 'vis/rendering') - os.makedirs(cfg['train']['dir_rendering'], exist_ok=True) - if cfg['train']['o3d_show']: - cfg['train']['dir_o3d'] = os.path.join(out_dir, 'vis/o3d') - os.makedirs(cfg['train']['dir_o3d'], exist_ok=True) + cfg["train"]["dir_model"] = os.path.join(out_dir, "model") + os.makedirs(cfg["train"]["dir_model"], exist_ok=True) + if cfg["train"]["exp_mesh"]: + cfg["train"]["dir_mesh"] = os.path.join(out_dir, "vis/mesh") + os.makedirs(cfg["train"]["dir_mesh"], exist_ok=True) + if cfg["train"]["exp_pcl"]: + cfg["train"]["dir_pcl"] = os.path.join(out_dir, "vis/pointcloud") + os.makedirs(cfg["train"]["dir_pcl"], exist_ok=True) + if cfg["train"]["vis_rendering"]: + cfg["train"]["dir_rendering"] = os.path.join(out_dir, "vis/rendering") + os.makedirs(cfg["train"]["dir_rendering"], exist_ok=True) + if cfg["train"]["o3d_show"]: + cfg["train"]["dir_o3d"] = os.path.join(out_dir, "vis/o3d") + os.makedirs(cfg["train"]["dir_o3d"], exist_ok=True) logger = logging.getLogger("train") logger.setLevel(logging.DEBUG) logger.handlers = [] # ch = logging.StreamHandler() # logger.addHandler(ch) - fh = logging.FileHandler(os.path.join(cfg['train']['out_dir'], "log.txt")) + fh = logging.FileHandler(os.path.join(cfg["train"]["out_dir"], "log.txt")) logger.addHandler(fh) - logger.info('Outout dir: %s', out_dir) + logger.info("Outout dir: %s", out_dir) return 
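update_config lets any leftover command-line argument override a config entry using a "--<section>:<key> <value>" convention, casting the value to whatever type the config already stores. A self-contained sketch with a toy config dict standing in for the loaded YAML:

from src.utils import update_config

cfg = {"train": {"batch_size": 32, "lr": 1e-3}, "model": {"psr_sigma": 2}}

cfg = update_config(cfg, ["--train:batch_size", "16", "--model:psr_sigma", "3"])
# prints "Changing train:batch_size ---- 32 to 16" (and likewise for psr_sigma);
# cfg["train"]["batch_size"] is now the int 16, not the string "16"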
logger + def update_recursive(dict1, dict2): - ''' Update two config dictionaries recursively. + """Update two config dictionaries recursively. Args: dict1 (dict): first dictionary to be updated dict2 (dict): second dictionary which entries should be used - ''' + """ for k, v in dict2.items(): if k not in dict1: dict1[k] = dict() @@ -472,6 +483,7 @@ def update_recursive(dict1, dict2): else: dict1[k] = v + def export_pointcloud(name, points, normals=None): if len(points.shape) > 2: points = points[0] @@ -487,6 +499,7 @@ def export_pointcloud(name, points, normals=None): pcd.normals = o3d.utility.Vector3dVector(normals) o3d.io.write_point_cloud(name, pcd) + def export_mesh(name, v, f): if len(v.shape) > 2: v, f = v[0], f[0] @@ -498,59 +511,63 @@ def export_mesh(name, v, f): mesh.triangles = o3d.utility.Vector3iVector(f) o3d.io.write_triangle_mesh(name, mesh) + def scale2onet(p, scale=1.2): - ''' - Scale the point cloud from SAP to ONet range - ''' + """Scale the point cloud from SAP to ONet range.""" return (p - 0.5) * scale - + + def update_optimizer(inputs, cfg, epoch, model=None, schedule=None): if model is not None: if schedule is not None: - optimizer = torch.optim.Adam([ - {"params": model.parameters(), - "lr": schedule[0].get_learning_rate(epoch)}, - {"params": inputs, - "lr": schedule[1].get_learning_rate(epoch)}]) - elif 'lr' in cfg['train']: - optimizer = torch.optim.Adam([ - {"params": model.parameters(), - "lr": float(cfg['train']['lr'])}, - {"params": inputs, - "lr": float(cfg['train']['lr_pcl'])}]) + optimizer = torch.optim.Adam( + [ + {"params": model.parameters(), "lr": schedule[0].get_learning_rate(epoch)}, + {"params": inputs, "lr": schedule[1].get_learning_rate(epoch)}, + ], + ) + elif "lr" in cfg["train"]: + optimizer = torch.optim.Adam( + [ + {"params": model.parameters(), "lr": float(cfg["train"]["lr"])}, + {"params": inputs, "lr": float(cfg["train"]["lr_pcl"])}, + ], + ) else: - raise Exception('no known learning rate') + msg = "no known learning rate" + raise Exception(msg) else: if schedule is not None: optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch)) else: - optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl'])) + optimizer = torch.optim.Adam([inputs], lr=float(cfg["train"]["lr_pcl"])) return optimizer def is_url(url): scheme = urllib.parse.urlparse(url).scheme - return scheme in ('http', 'https') + return scheme in ("http", "https") + def load_url(url): - '''Load a module dictionary from url. - + """Load a module dictionary from url. + Args: url (str): url to saved model - ''' + """ print(url) - print('=> Loading checkpoint from url...') + print("=> Loading checkpoint from url...") state_dict = model_zoo.load_url(url, progress=True) - + return state_dict class GaussianSmoothing(nn.Module): - """ - Apply gaussian smoothing on a + """Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input using a depthwise convolution. + Arguments: channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. kernel_size (int, sequence): Size of the gaussian kernel. @@ -558,8 +575,9 @@ class GaussianSmoothing(nn.Module): dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial). 
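update_optimizer builds one Adam optimizer over two parameter groups, the network weights and the input point cloud itself, each with its own learning rate. A sketch with a hypothetical stand-in module and cloud (the real code uses Encode2Points and cfg['train']['lr'] / cfg['train']['lr_pcl']):

import torch

model = torch.nn.Linear(3, 3)                             # stand-in network
inputs = torch.nn.Parameter(torch.rand(1, 2048, 3))       # optimizable point cloud

optimizer = torch.optim.Adam([
    {"params": model.parameters(), "lr": 1e-3},           # cfg['train']['lr']
    {"params": inputs,             "lr": 2e-3},           # cfg['train']['lr_pcl']
])
# adjust_learning_rate (below) rewrites each param_group["lr"] per epoch,
# relying on this group order matching the schedule list order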
""" + def __init__(self, channels, kernel_size, sigma, dim=3): - super(GaussianSmoothing, self).__init__() + super().__init__() if isinstance(kernel_size, numbers.Number): kernel_size = [kernel_size] * dim if isinstance(sigma, numbers.Number): @@ -569,15 +587,11 @@ class GaussianSmoothing(nn.Module): # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( - [ - torch.arange(size, dtype=torch.float32) - for size in kernel_size - ] + [torch.arange(size, dtype=torch.float32) for size in kernel_size], ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 - kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ - torch.exp(-((mgrid - mean) / std) ** 2 / 2) + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std) ** 2) / 2) # Make sure sum of values in gaussian kernel equals 1. kernel = kernel / torch.sum(kernel) @@ -586,7 +600,7 @@ class GaussianSmoothing(nn.Module): kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) - self.register_buffer('weight', kernel) + self.register_buffer("weight", kernel) self.groups = channels if dim == 1: @@ -596,36 +610,44 @@ class GaussianSmoothing(nn.Module): elif dim == 3: self.conv = F.conv3d else: + msg = f"Only 1, 2 and 3 dimensions are supported. Received {dim}." raise RuntimeError( - 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) + msg, ) def forward(self, input): - """ - Apply gaussian filter to input. + """Apply gaussian filter to input. + Arguments: input (torch.Tensor): Input to apply gaussian filter on. + Returns: filtered (torch.Tensor): Filtered output. """ return self.conv(input, weight=self.weight, groups=self.groups) - + + # Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py def get_learning_rate_schedules(schedule_specs): - schedules = [] for key in schedule_specs.keys(): - schedules.append(StepLearningRateSchedule( - schedule_specs[key]['initial'], + schedules.append( + StepLearningRateSchedule( + schedule_specs[key]["initial"], schedule_specs[key]["interval"], - schedule_specs[key]["factor"], - schedule_specs[key]["final"])) + schedule_specs[key]["factor"], + schedule_specs[key]["final"], + ), + ) return schedules + class LearningRateSchedule: def get_learning_rate(self, epoch): pass + + class StepLearningRateSchedule(LearningRateSchedule): def __init__(self, initial, interval, factor, final=1e-6): self.initial = float(initial) @@ -640,6 +662,7 @@ class StepLearningRateSchedule(LearningRateSchedule): else: return self.final + def adjust_learning_rate(lr_schedules, optimizer, epoch): for i, param_group in enumerate(optimizer.param_groups): param_group["lr"] = lr_schedules[i].get_learning_rate(epoch) diff --git a/src/visualize.py b/src/visualize.py index 0eec2da..346f178 100644 --- a/src/visualize.py +++ b/src/visualize.py @@ -1,95 +1,93 @@ import os -import torch + +import matplotlib.pyplot as plt import numpy as np import open3d as o3d -import matplotlib.pyplot as plt -from skimage import measure -from src.utils import calc_inters_points, grid_interp +import torch from scipy import ndimage -from tqdm import trange from torchvision.utils import save_image -from pdb import set_trace as st +from tqdm import trange + +from src.utils import calc_inters_points, grid_interp def visualize_points_mesh(vis, points, normals, verts, faces, cfg, it, epoch, color_v=None): - ''' Visualization. + """Visualization. 
Args: data (dict): data dictionary depth (int): PSR depth - out_path (str): output path for the mesh - ''' + out_path (str): output path for the mesh + """ mesh = o3d.geometry.TriangleMesh() mesh.vertices = o3d.utility.Vector3dVector(verts) mesh.triangles = o3d.utility.Vector3iVector(faces) - mesh.paint_uniform_color(np.array([0.7,0.7,0.7])) + mesh.paint_uniform_color(np.array([0.7, 0.7, 0.7])) if color_v is not None: mesh.vertex_colors = o3d.utility.Vector3dVector(color_v) - + if vis is not None: - dir_o3d = cfg['train']['dir_o3d'] - wire = o3d.geometry.LineSet.create_from_triangle_mesh(mesh) + dir_o3d = cfg["train"]["dir_o3d"] + o3d.geometry.LineSet.create_from_triangle_mesh(mesh) p = points.squeeze(0).detach().cpu().numpy() n = normals.squeeze(0).detach().cpu().numpy() pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(p) pcd.normals = o3d.utility.Vector3dVector(n) - pcd.paint_uniform_color(np.array([0.7,0.7,1.0])) + pcd.paint_uniform_color(np.array([0.7, 0.7, 1.0])) # pcd = pcd.uniform_down_sample(5) vis.clear_geometries() vis.add_geometry(mesh) vis.update_geometry(mesh) - + #! Thingi wheel - an example for how to change cameras in Open3D viewers - vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ]) - vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ]) - vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ]) + vis.get_view_control().set_front([0.0461, -0.7467, 0.6635]) + vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638]) + vis.get_view_control().set_up([0.0520, 0.6651, 0.7449]) vis.get_view_control().set_zoom(0.7) vis.poll_events() - - out_path = os.path.join(dir_o3d, '{}.jpg'.format(it)) + + out_path = os.path.join(dir_o3d, f"{it}.jpg") vis.capture_screen_image(out_path) vis.clear_geometries() vis.add_geometry(pcd, reset_bounding_box=False) vis.update_geometry(pcd) - vis.get_render_option().point_show_normal=True # visualize point normals + vis.get_render_option().point_show_normal = True # visualize point normals - vis.get_view_control().set_front([ 0.0461, -0.7467, 0.6635 ]) - vis.get_view_control().set_lookat([ 0.0092, 0.0078, 0.0638 ]) - vis.get_view_control().set_up([ 0.0520, 0.6651, 0.7449 ]) + vis.get_view_control().set_front([0.0461, -0.7467, 0.6635]) + vis.get_view_control().set_lookat([0.0092, 0.0078, 0.0638]) + vis.get_view_control().set_up([0.0520, 0.6651, 0.7449]) vis.get_view_control().set_zoom(0.7) vis.poll_events() - out_path = os.path.join(dir_o3d, '{}_pcd.jpg'.format(it)) + out_path = os.path.join(dir_o3d, f"{it}_pcd.jpg") vis.capture_screen_image(out_path) -def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video.mp4'): + +def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name="video.mp4"): if pose is not None: device = psr_grid.device # get world coordinate of grid points [-1, 1] res = psr_grid.shape[-1] x = torch.linspace(-1, 1, steps=res) co_x, co_y, co_z = torch.meshgrid(x, x, x) - co_grid = torch.stack( - [co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)], - dim=1).to(device).unsqueeze(0) + co_grid = torch.stack([co_x.reshape(-1), co_y.reshape(-1), co_z.reshape(-1)], dim=1).to(device).unsqueeze(0) # visualize the projected occ_soft value res = 128 psr_grid = psr_grid.reshape(-1) - out_mask = psr_grid>0 - in_mask = psr_grid<0 + out_mask = psr_grid > 0 + in_mask = psr_grid < 0 pix = pose.transform_points_screen(co_grid, ((res, res),))[..., :2].round().long().squeeze() - vis_mask = (pix[..., 0]>=0) & (pix[..., 0]<=res-1) & \ - (pix[..., 1]>=0) & (pix[..., 1]<=res-1) + 
vis_mask = (pix[..., 0] >= 0) & (pix[..., 0] <= res - 1) & (pix[..., 1] >= 0) & (pix[..., 1] <= res - 1) pix_out = pix[vis_mask & out_mask] pix_in = pix[vis_mask & in_mask] - - img = torch.ones([res,res]).to(device) - psr_grid = torch.sigmoid(- psr_grid * 5) + + img = torch.ones([res, res]).to(device) + psr_grid = torch.sigmoid(-psr_grid * 5) img[pix_out[:, 1], pix_out[:, 0]] = psr_grid[vis_mask & out_mask] img[pix_in[:, 1], pix_in[:, 0]] = psr_grid[vis_mask & in_mask] # save_image(img, 'tmp.png', normalize=True) @@ -98,78 +96,78 @@ def visualize_psr_grid(psr_grid, pose=None, out_dir=None, out_video_name='video. dir_psr_vis = out_dir os.makedirs(dir_psr_vis, exist_ok=True) psr_grid = psr_grid.squeeze().detach().cpu().numpy() - axis = ['x', 'y', 'z'] s = psr_grid.shape[0] for idx in trange(s): my_dpi = 100 - plt.figure(figsize=(1000/my_dpi, 300/my_dpi), dpi=my_dpi) + plt.figure(figsize=(1000 / my_dpi, 300 / my_dpi), dpi=my_dpi) plt.subplot(1, 3, 1) - plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode='nearest'), cmap='nipy_spectral') + plt.imshow(ndimage.rotate(psr_grid[idx], 180, mode="nearest"), cmap="nipy_spectral") plt.clim(-1, 1) plt.colorbar() - plt.title('x') + plt.title("x") plt.grid("off") plt.axis("off") plt.subplot(1, 3, 2) - plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode='nearest'), cmap='nipy_spectral') + plt.imshow(ndimage.rotate(psr_grid[:, idx], 180, mode="nearest"), cmap="nipy_spectral") plt.clim(-1, 1) plt.colorbar() - plt.title('y') + plt.title("y") plt.grid("off") plt.axis("off") plt.subplot(1, 3, 3) - plt.imshow(ndimage.rotate(psr_grid[:,:,idx], 90, mode='nearest'), cmap='nipy_spectral') + plt.imshow(ndimage.rotate(psr_grid[:, :, idx], 90, mode="nearest"), cmap="nipy_spectral") plt.clim(-1, 1) plt.colorbar() - plt.title('z') + plt.title("z") plt.grid("off") plt.axis("off") - - plt.savefig(os.path.join(dir_psr_vis, '{}'.format(idx)), pad_inches = 0, dpi=100) + plt.savefig(os.path.join(dir_psr_vis, f"{idx}"), pad_inches=0, dpi=100) plt.close() - os.system("rm {}/{}".format(dir_psr_vis, out_video_name)) - os.system("ffmpeg -framerate 25 -start_number 0 -i {}/%d.png -pix_fmt yuv420p -crf 17 {}/{}".format(dir_psr_vis, dir_psr_vis, out_video_name)) + os.system(f"rm {dir_psr_vis}/{out_video_name}") + os.system( + f"ffmpeg -framerate 25 -start_number 0 -i {dir_psr_vis}/%d.png -pix_fmt yuv420p -crf 17 {dir_psr_vis}/{out_video_name}", + ) + return None + return None -def visualize_mesh_phong(v, f, n, pose, img_size, name, device='cpu'): + +def visualize_mesh_phong(v, f, n, pose, img_size, name, device="cpu"): #! 
Mesh rendering using Phong shading model _, mask, f_p, w = calc_inters_points(v, f, pose, img_size) n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]] - n_inters = w[..., 0, None] * n_a.squeeze() + \ - w[..., 1, None] * n_b.squeeze() + \ - w[..., 2, None] * n_c.squeeze() + n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze() n_inters = n_inters.detach().to(device) - light_source = -pose.R@pose.T.squeeze() + light_source = -pose.R @ pose.T.squeeze() light = (light_source / light_source.norm(2)).permute(1, 0).to(device).float() - diffuse_per = torch.Tensor([0.7,0.7,0.7]).float() - ambiant = torch.Tensor([0.3,0.3,0.3]).float() + diffuse_per = torch.Tensor([0.7, 0.7, 0.7]).float() + ambiant = torch.Tensor([0.3, 0.3, 0.3]).float() diffuse = torch.mm(n_inters, light).clamp_min(0).repeat(1, 3) * diffuse_per.unsqueeze(0).to(device) - phong = torch.ones([img_size[0]*img_size[1], 3]).to(device) + phong = torch.ones([img_size[0] * img_size[1], 3]).to(device) phong[mask] = (ambiant.unsqueeze(0).to(device) + diffuse).clamp_max(1.0) pp = phong.reshape(img_size[0], img_size[1], -1) save_image(pp.permute(2, 0, 1), name) + def render_rgb(v, f, n, pose, renderer, img_size, mask_gt=None, ray=None, fea_grid=None): - p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt) - # normals for p_inters - n_inters = None - if n is not None: - n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]] - n_inters = w[..., 0, None] * n_a.squeeze() + \ - w[..., 1, None] * n_b.squeeze() + \ - w[..., 2, None] * n_c.squeeze() - if ray is not None: - ray = ray.squeeze()[mask] - - fea = None - if fea_grid is not None: - fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze() + p_inters, mask, f_p, w = calc_inters_points(v.detach(), f, pose, img_size, mask_gt=mask_gt) + # normals for p_inters + n_inters = None + if n is not None: + n_a, n_b, n_c = n[:, f_p[..., 0]], n[:, f_p[..., 1]], n[:, f_p[..., 2]] + n_inters = w[..., 0, None] * n_a.squeeze() + w[..., 1, None] * n_b.squeeze() + w[..., 2, None] * n_c.squeeze() + if ray is not None: + ray = ray.squeeze()[mask] - # use MLP to regress color - color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze() + fea = None + if fea_grid is not None: + fea = grid_interp(fea_grid, (p_inters.detach()[None] + 1) / 2).squeeze() - return color_pred, mask \ No newline at end of file + # use MLP to regress color + color_pred = renderer(p_inters, normals=n_inters, view_dirs=ray, feature_vectors=fea).squeeze() + + return color_pred, mask diff --git a/train.py b/train.py index 2bc08f8..f890574 100644 --- a/train.py +++ b/train.py @@ -4,182 +4,185 @@ abspath = os.path.abspath(__file__) dname = os.path.dirname(abspath) os.chdir(dname) +import argparse +import shutil +import time + +import numpy as np import torch import torch.optim as optim -import open3d as o3d - -import numpy as np; np.set_printoptions(precision=4) -import shutil, argparse, time from torch.utils.tensorboard import SummaryWriter from src import config -from src.data import collate_remove_none, collate_stack_together, worker_init_fn -from src.training import Trainer +from src.data import collate_remove_none, worker_init_fn from src.model import Encode2Points -from src.utils import load_config, initialize_logger, \ -AverageMeter, load_model_manual +from src.training import Trainer +from src.utils import AverageMeter, initialize_logger, load_config, 
load_model_manual + +np.set_printoptions(precision=4) def main(): - parser = argparse.ArgumentParser(description='MNIST toy experiment') - parser.add_argument('config', type=str, help='Path to config file.') - parser.add_argument('--no_cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') - + parser = argparse.ArgumentParser(description="MNIST toy experiment") + parser.add_argument("config", type=str, help="Path to config file.") + parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + args = parser.parse_args() - cfg = load_config(args.config, 'configs/default.yaml') + cfg = load_config(args.config, "configs/default.yaml") use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") - input_type = cfg['data']['input_type'] - batch_size = cfg['train']['batch_size'] - model_selection_metric = cfg['train']['model_selection_metric'] + cfg["data"]["input_type"] + batch_size = cfg["train"]["batch_size"] + model_selection_metric = cfg["train"]["model_selection_metric"] # PYTORCH VERSION > 1.0.0 - assert(float(torch.__version__.split('.')[-3]) > 0) + assert float(torch.__version__.split(".")[-3]) > 0 # boiler-plate - if cfg['train']['timestamp']: - cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S") + if cfg["train"]["timestamp"]: + cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S") logger = initialize_logger(cfg) torch.manual_seed(args.seed) np.random.seed(args.seed) - shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml')) + shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml")) logger.info("using GPU: " + torch.cuda.get_device_name(0)) # TensorboardX writer - tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log") + tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log") if not os.path.exists(tblogdir): os.makedirs(tblogdir, exist_ok=True) writer = SummaryWriter(log_dir=tblogdir) - inputs = None - train_dataset = config.get_dataset('train', cfg) - val_dataset = config.get_dataset('val', cfg) - vis_dataset = config.get_dataset('vis', cfg) + train_dataset = config.get_dataset("train", cfg) + val_dataset = config.get_dataset("val", cfg) + vis_dataset = config.get_dataset("vis", cfg) - collate_fn = collate_remove_none train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'], shuffle=True, - collate_fn=collate_fn, - worker_init_fn=worker_init_fn) + train_dataset, + batch_size=batch_size, + num_workers=cfg["train"]["n_workers"], + shuffle=True, + collate_fn=collate_fn, + worker_init_fn=worker_init_fn, + ) val_loader = torch.utils.data.DataLoader( - val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False, - collate_fn=collate_remove_none, - worker_init_fn=worker_init_fn) + val_dataset, + batch_size=1, + num_workers=cfg["train"]["n_workers_val"], + shuffle=False, + collate_fn=collate_remove_none, + worker_init_fn=worker_init_fn, + ) vis_loader = torch.utils.data.DataLoader( - vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'], shuffle=False, - collate_fn=collate_fn, - worker_init_fn=worker_init_fn) - + vis_dataset, + batch_size=1, + 
num_workers=cfg["train"]["n_workers_val"], + shuffle=False, + collate_fn=collate_fn, + worker_init_fn=worker_init_fn, + ) + if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(Encode2Points(cfg)).to(device) else: model = Encode2Points(cfg).to(device) n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info('Number of parameters: %d'% n_parameter) + logger.info("Number of parameters: %d" % n_parameter) # load model try: # load model - state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt')) - load_model_manual(state_dict['state_dict'], model) - - out = "Load model from iteration %d" % state_dict.get('it', 0) + state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt")) + load_model_manual(state_dict["state_dict"], model) + + out = "Load model from iteration %d" % state_dict.get("it", 0) logger.info(out) # load point cloud except: state_dict = dict() - - metric_val_best = state_dict.get( - 'loss_val_best', np.inf) - logger.info('Current best validation metric (%s): %.8f' - % (model_selection_metric, metric_val_best)) + metric_val_best = state_dict.get("loss_val_best", np.inf) - LR = float(cfg['train']['lr']) + logger.info(f"Current best validation metric ({model_selection_metric}): {metric_val_best:.8f}") + + LR = float(cfg["train"]["lr"]) optimizer = optim.Adam(model.parameters(), lr=LR) - start_epoch = state_dict.get('epoch', -1) - it = state_dict.get('it', -1) + start_epoch = state_dict.get("epoch", -1) + it = state_dict.get("it", -1) trainer = Trainer(cfg, optimizer, device=device) runtime = {} - runtime['all'] = AverageMeter() - - # training loop - for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1): + runtime["all"] = AverageMeter() + # training loop + for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1): for batch in train_loader: it += 1 - + start = time.time() loss, loss_each = trainer.train_step(inputs, batch, model) # measure elapsed time end = time.time() - runtime['all'].update(end - start) + runtime["all"].update(end - start) - if it % cfg['train']['print_every'] == 0: - log_text = ('[Epoch %02d] it=%d, loss=%.4f') %(epoch, it, loss) - writer.add_scalar('train/loss', loss, it) + if it % cfg["train"]["print_every"] == 0: + log_text = ("[Epoch %02d] it=%d, loss=%.4f") % (epoch, it, loss) + writer.add_scalar("train/loss", loss, it) if loss_each is not None: for k, l in loss_each.items(): - if l.item() != 0.: - log_text += (' loss_%s=%.4f') % (k, l.item()) - writer.add_scalar('train/%s' % k, l, it) - - log_text += (' time=%.3f / %.2f') % (runtime['all'].val, runtime['all'].sum) + if l.item() != 0.0: + log_text += f" loss_{k}={l.item():.4f}" + writer.add_scalar("train/%s" % k, l, it) + + log_text += (" time={:.3f} / {:.2f}").format(runtime["all"].val, runtime["all"].sum) logger.info(log_text) - if (it>0)& (it % cfg['train']['visualize_every'] == 0): + if (it > 0) & (it % cfg["train"]["visualize_every"] == 0): for i, batch_vis in enumerate(vis_loader): trainer.save(model, batch_vis, it, i) if i >= 4: break - logger.info('Saved mesh and pointcloud') + logger.info("Saved mesh and pointcloud") # run validation - if it > 0 and (it % cfg['train']['validate_every']) == 0: + if it > 0 and (it % cfg["train"]["validate_every"]) == 0: eval_dict = trainer.evaluate(val_loader, model) metric_val = eval_dict[model_selection_metric] - logger.info('Validation metric (%s): %.4f' - % (model_selection_metric, metric_val)) - - for k, v in eval_dict.items(): - writer.add_scalar('val/%s' % k, v, 
it) + logger.info(f"Validation metric ({model_selection_metric}): {metric_val:.4f}") - if -(metric_val - metric_val_best) >= 0: + for k, v in eval_dict.items(): + writer.add_scalar("val/%s" % k, v, it) + + if -(metric_val - metric_val_best) >= 0: metric_val_best = metric_val - logger.info('New best model (loss %.4f)' % metric_val_best) - state = {'epoch': epoch, - 'it': it, - 'loss_val_best': metric_val_best} - state['state_dict'] = model.state_dict() - torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt')) + logger.info("New best model (loss %.4f)" % metric_val_best) + state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best} + state["state_dict"] = model.state_dict() + torch.save(state, os.path.join(cfg["train"]["out_dir"], "model_best.pt")) # save checkpoint - if (epoch > 0) & (it % cfg['train']['checkpoint_every'] == 0): - state = {'epoch': epoch, - 'it': it, - 'loss_val_best': metric_val_best} - pcl = None - state['state_dict'] = model.state_dict() - - torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt')) + if (epoch > 0) & (it % cfg["train"]["checkpoint_every"] == 0): + state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best} + state["state_dict"] = model.state_dict() - if (it % cfg['train']['backup_every'] == 0): - torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt')) + torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt")) + + if it % cfg["train"]["backup_every"] == 0: + torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % it + ".pt")) logger.info("Backup model at iteration %d" % it) logger.info("Save new model at iteration %d" % it) - done=time.time() + time.time() -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main()
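A minimal sketch of the checkpoint round trip that train.py performs, with a stand-in Linear module and placeholder path and counters; load_model_manual adds or strips the "module." prefix so a checkpoint written with or without DataParallel loads either way:

import torch
from src.utils import load_model_manual

net = torch.nn.Linear(3, 3)                        # stands in for Encode2Points(cfg)
state = {"epoch": 3, "it": 5000, "loss_val_best": 0.123,
         "state_dict": torch.nn.DataParallel(net).state_dict()}   # keys gain "module."
torch.save(state, "model.pt")                      # stands in for out_dir + "/model.pt"

ckpt = torch.load("model.pt")
load_model_manual(ckpt["state_dict"], net)         # prefix handled automatically
it, epoch = ckpt.get("it", -1), ckpt.get("epoch", -1)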