# Shape-as-Point/optim.py
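
"""Optimization-based reconstruction with Shape As Points (SAP).

Instead of training a network, this script treats an oriented point cloud
(positions + normals) as the optimization variable and refines it so that
the surface obtained from it via Differentiable Poisson Surface
Reconstruction (DPSR, src/dpsr.py) fits the target point cloud or mesh
given in the config. The actual loss is computed by the Trainer in
src/optimization.py, not in this file.
"""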

import torch
import trimesh
import shutil, argparse, time, os, glob
import numpy as np; np.set_printoptions(precision=4)
import open3d as o3d
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from torchvision.io import write_video
from src.optimization import Trainer
from src.utils import load_config, update_config, initialize_logger, \
    get_learning_rate_schedules, adjust_learning_rate, AverageMeter, \
    update_optimizer, export_pointcloud
from skimage import measure
from plyfile import PlyData
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.structures import Meshes


def main():
    parser = argparse.ArgumentParser(
        description='Optimization-based Shape As Points reconstruction.')
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1457, metavar='S',
                        help='random seed')
    args, unknown = parser.parse_known_args()
    cfg = load_config(args.config, 'configs/default.yaml')
    cfg = update_config(cfg, unknown)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    data_type = cfg['data']['data_type']
    data_class = cfg['data']['class']
    print(cfg['train']['out_dir'])

    # requires PyTorch 1.0 or newer
    assert int(torch.__version__.split('.')[0]) >= 1

    # boiler-plate
    if cfg['train']['timestamp']:
        cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
    logger = initialize_logger(cfg)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    shutil.copyfile(args.config,
                    os.path.join(cfg['train']['out_dir'], 'config.yaml'))

    # tensorboard writer
    tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
    if not os.path.exists(tblogdir):
        os.makedirs(tblogdir)
    writer = SummaryWriter(log_dir=tblogdir)

    # initialize o3d visualizer
    vis = None
    if cfg['train']['o3d_show']:
        vis = o3d.visualization.Visualizer()
        vis.create_window(width=cfg['train']['o3d_window_size'],
                          height=cfg['train']['o3d_window_size'])

    # initialize dataset
    if data_type == 'point':
        if cfg['data']['object_id'] != -1:
            data_paths = sorted(glob.glob(cfg['data']['data_path']))
            data_path = data_paths[cfg['data']['object_id']]
            print('Loaded object %d/%d' % (cfg['data']['object_id'] + 1, len(data_paths)))
        else:
            data_path = cfg['data']['data_path']
            print('Data loaded')

        ext = data_path.split('.')[-1]
        if ext == 'obj':  # have a GT mesh
            mesh = load_objs_as_meshes([data_path], device=device)
            # scale the mesh into the unit cube
            verts = mesh.verts_packed()
            N = verts.shape[0]
            center = verts.mean(0)
            mesh.offset_verts_(-center.expand(N, 3))
            scale = max((verts - center).abs().max(0)[0])
            mesh.scale_verts_((1.0 / float(scale)))
            # important for our DPSR to have the range in [0, 1), not reaching 1
            mesh.scale_verts_(0.9)
            target_pts, target_normals = sample_points_from_meshes(
                mesh, num_samples=200000, return_normals=True)
        elif ext == 'ply':  # only have the point cloud
            plydata = PlyData.read(data_path)
            vertices = np.stack([plydata['vertex']['x'],
                                 plydata['vertex']['y'],
                                 plydata['vertex']['z']], axis=1)
            normals = np.stack([plydata['vertex']['nx'],
                                plydata['vertex']['ny'],
                                plydata['vertex']['nz']], axis=1)
            N = vertices.shape[0]
            center = vertices.mean(0)
            scale = np.max(np.max(np.abs(vertices - center), axis=0))
            vertices -= center
            vertices /= scale
            vertices *= 0.9
            target_pts = torch.tensor(vertices, device=device)[None].float()
            target_normals = torch.tensor(normals, device=device)[None].float()
            mesh = None  # no GT mesh

        if not torch.is_tensor(center):
            center = torch.from_numpy(center)
        if not torch.is_tensor(scale):
            scale = torch.from_numpy(np.array([scale]))

        data = {'target_points': target_pts,
                'target_normals': target_normals,  # normals are never used
                'gt_mesh': mesh}
    else:
        raise NotImplementedError
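
    # NOTE: the normalization above maps the target into [-0.9, 0.9]. `center`
    # and `scale` are kept so that outputs produced during optimization can be
    # mapped back to the original coordinate frame (see
    # trainer.save_mesh_pointclouds below).
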
    # save the input point cloud
    if 'target_points' in data.keys():
        outdir_pcl = os.path.join(cfg['train']['out_dir'], 'target_pcl.ply')
        if 'target_normals' in data.keys():
            export_pointcloud(outdir_pcl, data['target_points'], data['target_normals'])
        else:
            export_pointcloud(outdir_pcl, data['target_points'])
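
    # The "oracle" mesh below comes from running DPSR directly on points and
    # normals sampled from the ground-truth mesh; it presumably serves as a
    # reference for the best surface achievable at this grid resolution and
    # PSR sigma.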
    # save oracle PSR mesh (mesh from our PSR using GT points + normals)
    if data.get('gt_mesh') is not None:
        gt_verts, gt_faces = data['gt_mesh'].get_mesh_verts_faces(0)
        pts_gt, norms_gt = sample_points_from_meshes(
            data['gt_mesh'], num_samples=500000, return_normals=True)
        pts_gt = (pts_gt + 1) / 2  # map from [-1, 1] to [0, 1] for DPSR

        from src.dpsr import DPSR
        dpsr_tmp = DPSR(res=(cfg['model']['grid_res'],
                             cfg['model']['grid_res'],
                             cfg['model']['grid_res']),
                        sig=cfg['model']['psr_sigma']).to(device)
        target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
        target = torch.tanh(target)

        s = target.shape[-1]  # size of the PSR grid
        psr_grid_numpy = target.squeeze().detach().cpu().numpy()
        verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
        verts = verts / s * 2. - 1  # back to [-1, 1]

        mesh = o3d.geometry.TriangleMesh()
        mesh.vertices = o3d.utility.Vector3dVector(verts)
        mesh.triangles = o3d.utility.Vector3iVector(faces)
        outdir_mesh = os.path.join(cfg['train']['out_dir'], 'oracle_mesh.ply')
        o3d.io.write_triangle_mesh(outdir_mesh, mesh)

    # initialize the source point cloud given an input mesh
    if 'input_mesh' in cfg['train'].keys() and \
            os.path.isfile(cfg['train']['input_mesh']):
        if cfg['train']['input_mesh'].split('/')[-2] == 'mesh':
            mesh_tmp = trimesh.load_mesh(cfg['train']['input_mesh'])
            verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
            faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
            mesh = Meshes(verts=verts, faces=faces)
            points, normals = sample_points_from_meshes(
                mesh, num_samples=cfg['data']['num_points'], return_normals=True)
            # the mesh is saved in the original scale of the GT
            points -= center.float().to(device)
            points /= scale.float().to(device)
            points *= 0.9
            # make sure the points are within the range of [0, 1)
            points = points / 2. + 0.5
        else:
            # directly initialize from a point cloud
            pcd = o3d.io.read_point_cloud(cfg['train']['input_mesh'])
            points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
            normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
            points -= center.float().to(device)
            points /= scale.float().to(device)
            points *= 0.9
            points = points / 2. + 0.5
    else:
        # initialize our source point cloud from a sphere
        sphere_radius = cfg['model']['sphere_radius']
        sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius,
                                                 count=[256, 256])
        points, idx = sphere_mesh.sample(cfg['data']['num_points'],
                                         return_index=True)
        points += 0.5  # make sure the points are within the range of [0, 1)
        normals = sphere_mesh.face_normals[idx]
        points = torch.from_numpy(points).unsqueeze(0).to(device)
        normals = torch.from_numpy(normals).unsqueeze(0).to(device)

    points = torch.log(points / (1 - points))  # inverse sigmoid
    inputs = torch.cat([points, normals], axis=-1).float()
    inputs.requires_grad = True
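
    # `inputs` is a (1, N, 6) tensor: the first three channels are point
    # positions stored as logits (see the inverse sigmoid above), the last
    # three are normals. Keeping positions as logits means a sigmoid
    # (presumably applied inside the Trainer) maps them back into (0, 1),
    # the range DPSR expects.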
    model = None  # no network; the point cloud itself is optimized

    # initialize optimizer
    cfg['train']['schedule']['pcl']['initial'] = cfg['train']['lr_pcl']
    print('Initial learning rate:', cfg['train']['schedule']['pcl']['initial'])
    if 'schedule' in cfg['train']:
        lr_schedules = get_learning_rate_schedules(cfg['train']['schedule'])
    else:
        lr_schedules = None

    optimizer = update_optimizer(inputs, cfg,
                                 epoch=0, model=model, schedule=lr_schedules)

    # resume from a checkpoint if one exists in the output directory
    try:
        state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
        if ('pcl' in state_dict.keys()) and (state_dict['pcl'] is not None):
            inputs = state_dict['pcl'].to(device)
            inputs.requires_grad = True
        optimizer = update_optimizer(inputs, cfg,
                                     epoch=state_dict.get('epoch'), schedule=lr_schedules)
        out = "Loaded model from epoch %d" % state_dict.get('epoch', 0)
        print(out)
        logger.info(out)
    except Exception:  # no (valid) checkpoint found, start from scratch
        state_dict = dict()

    start_epoch = state_dict.get('epoch', -1)
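
    # The Trainer (src/optimization.py) provides train_step, visualize,
    # save_mesh_pointclouds and point_resampling used in the loop below;
    # roughly, each step reconstructs a surface from the current oriented
    # points via DPSR and measures how well it fits the target (the exact
    # loss is defined outside this script).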
    trainer = Trainer(cfg, optimizer, device=device)
    runtime = {}
    runtime['all'] = AverageMeter()

    # training loop
    for epoch in range(start_epoch + 1, cfg['train']['total_epochs'] + 1):

        # schedule the learning rate
        if (epoch > 0) and (lr_schedules is not None):
            if epoch % lr_schedules[0].interval == 0:
                adjust_learning_rate(lr_schedules, optimizer, epoch)
                if len(lr_schedules) > 1:
                    print('[epoch {}] net_lr: {}, pcl_lr: {}'.format(epoch,
                          lr_schedules[0].get_learning_rate(epoch),
                          lr_schedules[1].get_learning_rate(epoch)))
                else:
                    print('[epoch {}] adjust pcl_lr to: {}'.format(epoch,
                          lr_schedules[0].get_learning_rate(epoch)))

        # perform one optimization step
        start = time.time()
        loss, loss_each = trainer.train_step(data, inputs, model, epoch)
        runtime['all'].update(time.time() - start)

        # logging
        if epoch % cfg['train']['print_every'] == 0:
            log_text = ('[Epoch %02d] loss=%.5f') % (epoch, loss)
            if loss_each is not None:
                for k, l in loss_each.items():
                    if l.item() != 0.:
                        log_text += (' loss_%s=%.5f') % (k, l.item())
            log_text += (' time=%.3f / %.3f') % (runtime['all'].val,
                                                 runtime['all'].sum)
            logger.info(log_text)
            print(log_text)

        # visualize point clouds and meshes
        if (epoch % cfg['train']['visualize_every'] == 0) and (vis is not None):
            trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)

        # save outputs (de-normalized back to the original coordinate frame)
        if epoch % cfg['train']['save_every'] == 0:
            trainer.save_mesh_pointclouds(inputs, epoch,
                                          center.cpu().numpy(),
                                          scale.cpu().numpy() * (1 / 0.9))

        # save checkpoints
        if (epoch > 0) and (epoch % cfg['train']['checkpoint_every'] == 0):
            state = {'epoch': epoch}
            if isinstance(inputs, torch.Tensor):
                state['pcl'] = inputs.detach().cpu()
            torch.save(state, os.path.join(cfg['train']['dir_model'],
                                           '%04d' % epoch + '.pt'))
            print("Saved new model at epoch %d" % epoch)
            logger.info("Saved new model at epoch %d" % epoch)
            torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))
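
        # point_resampling below replaces `inputs` with a new tensor, so the
        # optimizer (which still references the old leaf tensor) and the
        # Trainer are rebuilt around the resampled point cloud.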
        # resample and gradually add new points to the source point cloud
        if (epoch > 0) and \
           (cfg['train']['resample_every'] != 0) and \
           (epoch % cfg['train']['resample_every'] == 0) and \
           (epoch < cfg['train']['total_epochs']):
            inputs = trainer.point_resampling(inputs)
            optimizer = update_optimizer(inputs, cfg,
                                         epoch=epoch, model=model, schedule=lr_schedules)
            trainer = Trainer(cfg, optimizer, device=device)

    # render the Open3D outputs into videos
    if cfg['train']['o3d_show']:
        out_video_dir = os.path.join(cfg['train']['out_dir'],
                                     'vis/o3d/video.mp4')
        if os.path.isfile(out_video_dir):
            os.remove(out_video_dir)
        os.system(('ffmpeg -framerate 30 -start_number 0 '
                   '-i {}/vis/o3d/%04d.jpg '
                   '-pix_fmt yuv420p -crf 17 {}').format(
                       cfg['train']['out_dir'], out_video_dir))

        out_video_dir = os.path.join(cfg['train']['out_dir'],
                                     'vis/o3d/video_pcd.mp4')
        if os.path.isfile(out_video_dir):
            os.remove(out_video_dir)
        os.system(('ffmpeg -framerate 30 -start_number 0 '
                   '-i {}/vis/o3d/%04d_pcd.jpg '
                   '-pix_fmt yuv420p -crf 17 {}').format(
                       cfg['train']['out_dir'], out_video_dir))
        print('Videos saved.')


if __name__ == '__main__':
    main()