import torch import trimesh import shutil, argparse, time, os, glob import numpy as np; np.set_printoptions(precision=4) import open3d as o3d from torch.utils.tensorboard import SummaryWriter from torchvision.utils import save_image from torchvision.io import write_video from src.optimization import Trainer from src.utils import load_config, update_config, initialize_logger, \ get_learning_rate_schedules, adjust_learning_rate, AverageMeter,\ update_optimizer, export_pointcloud from skimage import measure from plyfile import PlyData from pytorch3d.ops import sample_points_from_meshes from pytorch3d.io import load_objs_as_meshes from pytorch3d.structures import Meshes def main(): parser = argparse.ArgumentParser(description='MNIST toy experiment') parser.add_argument('config', type=str, help='Path to config file.') parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=1457, metavar='S', help='random seed') args, unknown = parser.parse_known_args() cfg = load_config(args.config, 'configs/default.yaml') cfg = update_config(cfg, unknown) use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") data_type = cfg['data']['data_type'] data_class = cfg['data']['class'] print(cfg['train']['out_dir']) # PYTORCH VERSION > 1.0.0 assert(float(torch.__version__.split('.')[-3]) > 0) # boiler-plate if cfg['train']['timestamp']: cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S") logger = initialize_logger(cfg) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) np.random.seed(args.seed) shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml')) # tensorboardX writer tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log") if not os.path.exists(tblogdir): os.makedirs(tblogdir) writer = SummaryWriter(log_dir=tblogdir) # initialize o3d visualizer vis = None if cfg['train']['o3d_show']: vis = o3d.visualization.Visualizer() vis.create_window(width=cfg['train']['o3d_window_size'], height=cfg['train']['o3d_window_size']) # initialize dataset if data_type == 'point': if cfg['data']['object_id'] != -1: data_paths = sorted(glob.glob(cfg['data']['data_path'])) data_path = data_paths[cfg['data']['object_id']] print('Loaded %d/%d object' % (cfg['data']['object_id']+1, len(data_paths))) else: data_path = cfg['data']['data_path'] print('Data loaded') ext = data_path.split('.')[-1] if ext == 'obj': # have GT mesh mesh = load_objs_as_meshes([data_path], device=device) # scale the mesh into unit cube verts = mesh.verts_packed() N = verts.shape[0] center = verts.mean(0) mesh.offset_verts_(-center.expand(N, 3)) scale = max((verts - center).abs().max(0)[0]) mesh.scale_verts_((1.0 / float(scale))) # important for our DPSR to have the range in [0, 1), not reaching 1 mesh.scale_verts_(0.9) target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True) elif ext == 'ply': # only have the point cloud plydata = PlyData.read(data_path) vertices = np.stack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']], axis=1) normals = np.stack([plydata['vertex']['nx'], plydata['vertex']['ny'], plydata['vertex']['nz']], axis=1) N = vertices.shape[0] center = vertices.mean(0) scale = np.max(np.max(np.abs(vertices - center), axis=0)) vertices -= center vertices /= scale vertices *= 0.9 target_pts = torch.tensor(vertices, device=device)[None].float() target_normals = torch.tensor(normals, device=device)[None].float() mesh = None # no GT mesh if not torch.is_tensor(center): center = torch.from_numpy(center) if not torch.is_tensor(scale): scale = torch.from_numpy(np.array([scale])) data = {'target_points': target_pts, 'target_normals': target_normals, # normals are never used 'gt_mesh': mesh} else: raise NotImplementedError # save the input point cloud if 'target_points' in data.keys(): outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply') if 'target_normals' in data.keys(): export_pointcloud(outdir_pcl, data['target_points'], data['target_normals']) else: export_pointcloud(outdir_pcl, data['target_points']) # save oracle PSR mesh (mesh from our PSR using GT point+normals) if data.get('gt_mesh') is not None: gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0) pts_gt, norms_gt = sample_points_from_meshes(data['gt_mesh'], num_samples=500000, return_normals=True) pts_gt = (pts_gt + 1) / 2 from src.dpsr import DPSR dpsr_tmp = DPSR(res=(cfg['model']['grid_res'], cfg['model']['grid_res'], cfg['model']['grid_res']), sig=cfg['model']['psr_sigma']).to(device) target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device) target = torch.tanh(target) s = target.shape[-1] # size of psr_grid psr_grid_numpy = target.squeeze().detach().cpu().numpy() verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy) verts = verts / s * 2. - 1 # [-1, 1] mesh = o3d.geometry.TriangleMesh() mesh.vertices = o3d.utility.Vector3dVector(verts) mesh.triangles = o3d.utility.Vector3iVector(faces) outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply') o3d.io.write_triangle_mesh(outdir_mesh, mesh) # initialize the source point cloud given an input mesh if 'input_mesh' in cfg['train'].keys() and \ os.path.isfile(cfg['train']['input_mesh']): if cfg['train']['input_mesh'].split('/')[-2] == 'mesh': mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh']) verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device) faces = torch.from_numpy(mesh_tmp.faces[None]).to(device) mesh = Meshes(verts=verts, faces=faces) points, normals = sample_points_from_meshes(mesh, num_samples=cfg['data']['num_points'], return_normals=True) # mesh is saved in the original scale of the gt points -= center.float().to(device) points /= scale.float().to(device) points *= 0.9 # make sure the points are within the range of [0, 1) points = points / 2. + 0.5 else: # directly initialize from a point cloud pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh']) points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device) normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device) points -= center.float().to(device) points /= scale.float().to(device) points *= 0.9 points = points / 2. + 0.5 else: #! initialize our source point cloud from a sphere sphere_radius = cfg['model']['sphere_radius'] sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256,256]) points, idx = sphere_mesh.sample(cfg['data']['num_points'], return_index=True) points += 0.5 # make sure the points are within the range of [0, 1) normals = sphere_mesh.face_normals[idx] points = torch.from_numpy(points).unsqueeze(0).to(device) normals = torch.from_numpy(normals).unsqueeze(0).to(device) points = torch.log(points/(1-points)) # inverse sigmoid inputs = torch.cat([points, normals], axis=-1).float() inputs.requires_grad = True model = None # no network # initialize optimizer cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl'] print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial']) if 'schedule' in cfg['train']: lr_schedules = get_learning_rate_schedules(cfg['train']['schedule']) else: lr_schedules = None optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules) try: # load model state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt')) if ('pcl' in state_dict.keys()) & (state_dict['pcl'] is not None): inputs = state_dict['pcl'].to(device) inputs.requires_grad = True optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get('epoch'), schedule=lr_schedules) out = "Load model from epoch %d" % state_dict.get('epoch', 0) print(out) logger.info(out) except: state_dict = dict() start_epoch = state_dict.get('epoch', -1) trainer = Trainer(cfg, optimizer, device=device) runtime = {} runtime['all'] = AverageMeter() # training loop for epoch in range(start_epoch+1, cfg['train']['total_epochs']+1): # schedule the learning rate if (epoch>0) & (lr_schedules is not None): if (epoch % lr_schedules[0].interval == 0): adjust_learning_rate(lr_schedules, optimizer, epoch) if len(lr_schedules) >1: print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch, lr_schedules[0].get_learning_rate(epoch), lr_schedules[1].get_learning_rate(epoch))) else: print('[epoch {}] adjust pcl_lr to: {}'.format(epoch, lr_schedules[0].get_learning_rate(epoch))) start = time.time() loss, loss_each = trainer.train_step(data, inputs, model, epoch) runtime['all'].update(time.time() - start) if epoch % cfg['train']['print_every'] == 0: log_text = ('[Epoch %02d] loss=%.5f') %(epoch, loss) if loss_each is not None: for k, l in loss_each.items(): if l.item() != 0.: log_text += (' loss_%s=%.5f') % (k, l.item()) log_text += (' time=%.3f / %.3f') % (runtime['all'].val, runtime['all'].sum) logger.info(log_text) print(log_text) # visualize point clouds and meshes if (epoch % cfg['train']['visualize_every'] == 0) & (vis is not None): trainer.visualize(data, inputs, model, epoch, o3d_vis=vis) # save outputs if epoch % cfg['train']['save_every'] == 0: trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy()*(1/0.9)) # save checkpoints if (epoch > 0) & (epoch % cfg['train']['checkpoint_every'] == 0): state = {'epoch': epoch} pcl = None if isinstance(inputs, torch.Tensor): state['pcl'] = inputs.detach().cpu() torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % epoch + '.pt')) print("Save new model at epoch %d" % epoch) logger.info("Save new model at epoch %d" % epoch) torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt')) # resample and gradually add new points to the source pcl if (epoch > 0) & \ (cfg['train']['resample_every']!=0) & \ (epoch % cfg['train']['resample_every'] == 0) & \ (epoch < cfg['train']['total_epochs']): inputs = trainer.point_resampling(inputs) optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules) trainer = Trainer(cfg, optimizer, device=device) # visualize the Open3D outputs if cfg['train']['o3d_show']: out_video_dir = os.path.join(cfg['train']['out_dir'], 'vis/o3d/video.mp4') if os.path.isfile(out_video_dir): os.system('rm {}'.format(out_video_dir)) os.system('ffmpeg -framerate 30 \ -start_number 0 \ -i {}/vis/o3d/%04d.jpg \ -pix_fmt yuv420p \ -crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir)) out_video_dir = os.path.join(cfg['train']['out_dir'], 'vis/o3d/video_pcd.mp4') if os.path.isfile(out_video_dir): os.system('rm {}'.format(out_video_dir)) os.system('ffmpeg -framerate 30 \ -start_number 0 \ -i {}/vis/o3d/%04d_pcd.jpg \ -pix_fmt yuv420p \ -crf 17 {}'.format(cfg['train']['out_dir'], out_video_dir)) print('Video saved.') if __name__ == '__main__': main()