# Shape-as-Point/optim.py
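"""Optimization-based surface reconstruction with Shape-as-Point.

Given a target point cloud (read from a .ply file or sampled from a .obj mesh),
this script directly optimizes an oriented source point cloud, initialized from
a sphere or a user-provided mesh/point cloud, so that its differentiable Poisson
surface reconstruction (DPSR) matches the target. No network is trained: the
point coordinates and normals themselves are the optimization variables.
"""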

import argparse
import glob
import os
import shutil
import time
import numpy as np
import open3d as o3d
import torch
import trimesh
from plyfile import PlyData
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes
from skimage import measure
from torch.utils.tensorboard import SummaryWriter
from src.optimization import Trainer
from src.utils import (
AverageMeter,
adjust_learning_rate,
export_pointcloud,
get_learning_rate_schedules,
initialize_logger,
load_config,
update_config,
update_optimizer,
)
np.set_printoptions(precision=4)


def main():
    parser = argparse.ArgumentParser(description="Shape-as-Point: optimization-based reconstruction")
parser.add_argument("config", type=str, help="Path to config file.")
parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1457, metavar="S", help="random seed")
args, unknown = parser.parse_known_args()
cfg = load_config(args.config, "configs/default.yaml")
cfg = update_config(cfg, unknown)
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
data_type = cfg["data"]["data_type"]
cfg["data"]["class"]
print(cfg["train"]["out_dir"])
    # require PyTorch >= 1.0 (robust to versions like "2.0.1+cu117")
    assert int(torch.__version__.split(".")[0]) >= 1
# boiler-plate
if cfg["train"]["timestamp"]:
cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
logger = initialize_logger(cfg)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))
    # TensorBoard writer (created alongside the run; this script does not log scalars to it)
    tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
    os.makedirs(tblogdir, exist_ok=True)
    tb_writer = SummaryWriter(log_dir=tblogdir)
# initialize o3d visualizer
vis = None
if cfg["train"]["o3d_show"]:
vis = o3d.visualization.Visualizer()
vis.create_window(width=cfg["train"]["o3d_window_size"], height=cfg["train"]["o3d_window_size"])
# initialize dataset
if data_type == "point":
if cfg["data"]["object_id"] != -1:
data_paths = sorted(glob.glob(cfg["data"]["data_path"]))
data_path = data_paths[cfg["data"]["object_id"]]
print("Loaded %d/%d object" % (cfg["data"]["object_id"] + 1, len(data_paths)))
else:
data_path = cfg["data"]["data_path"]
print("Data loaded")
ext = data_path.split(".")[-1]
if ext == "obj": # have GT mesh
mesh = load_objs_as_meshes([data_path], device=device)
# scale the mesh into unit cube
verts = mesh.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
mesh.offset_verts_(-center.expand(N, 3))
scale = max((verts - center).abs().max(0)[0])
mesh.scale_verts_(1.0 / float(scale))
# important for our DPSR to have the range in [0, 1), not reaching 1
mesh.scale_verts_(0.9)
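            # centering and dividing by the max half-extent maps vertices into [-1, 1]^3;
            # the extra 0.9 factor keeps them strictly inside the cube so that the later
            # (x + 1) / 2 remap lands in [0.05, 0.95] ⊂ [0, 1), as the DPSR grid requires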
target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True)
elif ext == "ply": # only have the point cloud
plydata = PlyData.read(data_path)
vertices = np.stack([plydata["vertex"]["x"], plydata["vertex"]["y"], plydata["vertex"]["z"]], axis=1)
normals = np.stack([plydata["vertex"]["nx"], plydata["vertex"]["ny"], plydata["vertex"]["nz"]], axis=1)
N = vertices.shape[0]
center = vertices.mean(0)
scale = np.max(np.max(np.abs(vertices - center), axis=0))
vertices -= center
vertices /= scale
vertices *= 0.9
target_pts = torch.tensor(vertices, device=device)[None].float()
target_normals = torch.tensor(normals, device=device)[None].float()
mesh = None # no GT mesh
if not torch.is_tensor(center):
center = torch.from_numpy(center)
if not torch.is_tensor(scale):
scale = torch.from_numpy(np.array([scale]))
        data = {
            "target_points": target_pts,
            "target_normals": target_normals,  # exported for inspection; not used by the optimization loss
            "gt_mesh": mesh,
        }
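        # gt_mesh (when present) is only consulted below to build an "oracle" PSR mesh
        # from GT samples, a best-case reference at the chosen grid resolution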
else:
raise NotImplementedError
# save the input point cloud
if "target_points" in data.keys():
outdir_pcl = os.path.join(cfg["train"]["out_dir"], "target_pcl.ply")
if "target_normals" in data.keys():
export_pointcloud(outdir_pcl, data["target_points"], data["target_normals"])
else:
export_pointcloud(outdir_pcl, data["target_points"])
# save oracle PSR mesh (mesh from our PSR using GT point+normals)
if data.get("gt_mesh") is not None:
pts_gt, norms_gt = sample_points_from_meshes(data["gt_mesh"], num_samples=500000, return_normals=True)
pts_gt = (pts_gt + 1) / 2
from src.dpsr import DPSR
dpsr_tmp = DPSR(
res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
sig=cfg["model"]["psr_sigma"],
).to(device)
target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
target = torch.tanh(target)
s = target.shape[-1] # size of psr_grid
psr_grid_numpy = target.squeeze().detach().cpu().numpy()
verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
verts = verts / s * 2.0 - 1 # [-1, 1]
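        # marching_cubes returns vertices in voxel-index coordinates [0, s); dividing by
        # the grid size and mapping to [-1, 1] matches the normalized frame of the GT mesh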
        # build the oracle mesh under a fresh name so the GT mesh reference is not clobbered
        oracle_mesh = o3d.geometry.TriangleMesh()
        oracle_mesh.vertices = o3d.utility.Vector3dVector(verts)
        oracle_mesh.triangles = o3d.utility.Vector3iVector(faces)
        outdir_mesh = os.path.join(cfg["train"]["out_dir"], "oracle_mesh.ply")
        o3d.io.write_triangle_mesh(outdir_mesh, oracle_mesh)
# initialize the source point cloud given an input mesh
if "input_mesh" in cfg["train"].keys() and os.path.isfile(cfg["train"]["input_mesh"]):
if cfg["train"]["input_mesh"].split("/")[-2] == "mesh":
mesh_tmp = trimesh.load_mesh(cfg["train"]["input_mesh"])
verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
mesh = Meshes(verts=verts, faces=faces)
points, normals = sample_points_from_meshes(
mesh, num_samples=cfg["data"]["num_points"], return_normals=True,
)
# mesh is saved in the original scale of the gt
points -= center.float().to(device)
points /= scale.float().to(device)
points *= 0.9
# make sure the points are within the range of [0, 1)
points = points / 2.0 + 0.5
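            # x / 2 + 0.5 maps the normalized [-0.9, 0.9] coordinates into [0.05, 0.95],
            # the same strictly-inside-[0, 1) frame the DPSR grid expects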
else:
# directly initialize from a point cloud
pcd = o3d.io.read_point_cloud(cfg["train"]["input_mesh"])
points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
points -= center.float().to(device)
points /= scale.float().to(device)
points *= 0.9
points = points / 2.0 + 0.5
else: #! initialize our source point cloud from a sphere
sphere_radius = cfg["model"]["sphere_radius"]
sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256, 256])
points, idx = sphere_mesh.sample(cfg["data"]["num_points"], return_index=True)
points += 0.5 # make sure the points are within the range of [0, 1)
normals = sphere_mesh.face_normals[idx]
points = torch.from_numpy(points).unsqueeze(0).to(device)
normals = torch.from_numpy(normals).unsqueeze(0).to(device)
points = torch.log(points / (1 - points)) # inverse sigmoid
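        # optimize unconstrained logits: log(p / (1 - p)) inverts the sigmoid, so applying
        # torch.sigmoid downstream (presumably inside the Trainer) recovers coordinates in
        # (0, 1), and no gradient step can push a point outside the valid range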
inputs = torch.cat([points, normals], axis=-1).float()
inputs.requires_grad = True
model = None # no network
    # initialize optimizer
    if "schedule" in cfg["train"]:
        cfg["train"]["schedule"]["pcl"]["initial"] = cfg["train"]["lr_pcl"]
        print("Initial learning rate:", cfg["train"]["schedule"]["pcl"]["initial"])
        lr_schedules = get_learning_rate_schedules(cfg["train"]["schedule"])
    else:
        lr_schedules = None
optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules)
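    # with model=None the optimizer's only parameter is the raw logit-space point/normal
    # tensor itself; gradients from Trainer.train_step move the points directly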
try:
# load model
state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
if ("pcl" in state_dict.keys()) & (state_dict["pcl"] is not None):
inputs = state_dict["pcl"].to(device)
inputs.requires_grad = True
optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get("epoch"), schedule=lr_schedules)
out = "Load model from epoch %d" % state_dict.get("epoch", 0)
print(out)
logger.info(out)
    except FileNotFoundError:
        # no checkpoint yet; start from scratch
        state_dict = dict()
start_epoch = state_dict.get("epoch", -1)
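    # start_epoch is -1 on a fresh run (empty state_dict), so training starts at epoch 0;
    # when resuming, the loop continues from the checkpointed epoch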
trainer = Trainer(cfg, optimizer, device=device)
runtime = {}
runtime["all"] = AverageMeter()
# training loop
for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
# schedule the learning rate
if (epoch > 0) & (lr_schedules is not None):
if epoch % lr_schedules[0].interval == 0:
adjust_learning_rate(lr_schedules, optimizer, epoch)
if len(lr_schedules) > 1:
print(
"[epoch {}] net_lr: {}, pcl_lr: {}".format(
epoch, lr_schedules[0].get_learning_rate(epoch), lr_schedules[1].get_learning_rate(epoch),
),
)
else:
print(f"[epoch {epoch}] adjust pcl_lr to: {lr_schedules[0].get_learning_rate(epoch)}")
start = time.time()
loss, loss_each = trainer.train_step(data, inputs, model, epoch)
runtime["all"].update(time.time() - start)
if epoch % cfg["train"]["print_every"] == 0:
log_text = ("[Epoch %02d] loss=%.5f") % (epoch, loss)
if loss_each is not None:
for k, l in loss_each.items():
if l.item() != 0.0:
log_text += f" loss_{k}={l.item():.5f}"
log_text += (" time={:.3f} / {:.3f}").format(runtime["all"].val, runtime["all"].sum)
logger.info(log_text)
print(log_text)
# visualize point clouds and meshes
        if epoch % cfg["train"]["visualize_every"] == 0 and vis is not None:
trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)
# save outputs
if epoch % cfg["train"]["save_every"] == 0:
trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy() * (1 / 0.9))
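            # scale * (1 / 0.9) undoes the 0.9 shrink applied during normalization, so
            # saved meshes are written back in the original coordinate frame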
# save checkpoints
if (epoch > 0) & (epoch % cfg["train"]["checkpoint_every"] == 0):
state = {"epoch": epoch}
if isinstance(inputs, torch.Tensor):
state["pcl"] = inputs.detach().cpu()
torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % epoch + ".pt"))
print("Save new model at epoch %d" % epoch)
logger.info("Save new model at epoch %d" % epoch)
torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))
# resample and gradually add new points to the source pcl
        if (
            epoch > 0
            and cfg["train"]["resample_every"] != 0
            and epoch % cfg["train"]["resample_every"] == 0
            and epoch < cfg["train"]["total_epochs"]
        ):
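            # resampling returns a new leaf tensor, so the optimizer (and the Trainer
            # holding it) must be rebuilt around the new parameter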
inputs = trainer.point_resampling(inputs)
optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules)
trainer = Trainer(cfg, optimizer, device=device)
# visualize the Open3D outputs
if cfg["train"]["o3d_show"]:
        out_video_path = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video.mp4")
        if os.path.isfile(out_video_path):
            os.remove(out_video_path)
        os.system(
            "ffmpeg -framerate 30 \
            -start_number 0 \
            -i {}/vis/o3d/%04d.jpg \
            -pix_fmt yuv420p \
            -crf 17 {}".format(
                cfg["train"]["out_dir"], out_video_path,
            ),
        )
        out_video_path = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video_pcd.mp4")
        if os.path.isfile(out_video_path):
            os.remove(out_video_path)
        os.system(
            "ffmpeg -framerate 30 \
            -start_number 0 \
            -i {}/vis/o3d/%04d_pcd.jpg \
            -pix_fmt yuv420p \
            -crf 17 {}".format(
                cfg["train"]["out_dir"], out_video_path,
            ),
        )
print("Video saved.")


if __name__ == "__main__":
main()