# Shape-as-Point/optim.py

import argparse
import glob
import os
import shutil
import time

import numpy as np
import open3d as o3d
import torch
import trimesh
from plyfile import PlyData
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes
from skimage import measure
from torch.utils.tensorboard import SummaryWriter

from src.optimization import Trainer
from src.utils import (
    AverageMeter,
    adjust_learning_rate,
    export_pointcloud,
    get_learning_rate_schedules,
    initialize_logger,
    load_config,
    update_config,
    update_optimizer,
)

np.set_printoptions(precision=4)


def main():
    parser = argparse.ArgumentParser(description="Shape-as-Points optimization-based reconstruction.")
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1457, metavar="S", help="random seed")
    args, unknown = parser.parse_known_args()
    cfg = load_config(args.config, "configs/default.yaml")
    cfg = update_config(cfg, unknown)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    data_type = cfg["data"]["data_type"]
    print(cfg["train"]["out_dir"])

    # PyTorch >= 1.0.0 is required
    assert int(torch.__version__.split(".")[0]) >= 1

    # boiler-plate
    if cfg["train"]["timestamp"]:
        cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
    logger = initialize_logger(cfg)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))

    # tensorboard writer
    tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
    if not os.path.exists(tblogdir):
        os.makedirs(tblogdir)
    writer = SummaryWriter(log_dir=tblogdir)  # unused below, but instantiation creates the event file

    # initialize o3d visualizer
    vis = None
    if cfg["train"]["o3d_show"]:
        vis = o3d.visualization.Visualizer()
        vis.create_window(width=cfg["train"]["o3d_window_size"], height=cfg["train"]["o3d_window_size"])

    # initialize dataset
    if data_type == "point":
        if cfg["data"]["object_id"] != -1:
            data_paths = sorted(glob.glob(cfg["data"]["data_path"]))
            data_path = data_paths[cfg["data"]["object_id"]]
            print("Loaded %d/%d object" % (cfg["data"]["object_id"] + 1, len(data_paths)))
        else:
            data_path = cfg["data"]["data_path"]
            print("Data loaded")

        ext = data_path.split(".")[-1]
if ext == "obj": # have GT mesh
2021-11-08 10:09:50 +00:00
mesh = load_objs_as_meshes([data_path], device=device)
# scale the mesh into unit cube
verts = mesh.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
mesh.offset_verts_(-center.expand(N, 3))
scale = max((verts - center).abs().max(0)[0])
2023-05-26 12:59:53 +00:00
mesh.scale_verts_(1.0 / float(scale))
2021-11-08 10:09:50 +00:00
# important for our DPSR to have the range in [0, 1), not reaching 1
mesh.scale_verts_(0.9)
2023-05-26 12:59:53 +00:00
target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True)
elif ext == "ply": # only have the point cloud
2021-11-08 10:09:50 +00:00
plydata = PlyData.read(data_path)
2023-05-26 12:59:53 +00:00
vertices = np.stack([plydata["vertex"]["x"], plydata["vertex"]["y"], plydata["vertex"]["z"]], axis=1)
normals = np.stack([plydata["vertex"]["nx"], plydata["vertex"]["ny"], plydata["vertex"]["nz"]], axis=1)
2021-11-08 10:09:50 +00:00
N = vertices.shape[0]
center = vertices.mean(0)
scale = np.max(np.max(np.abs(vertices - center), axis=0))
vertices -= center
vertices /= scale
vertices *= 0.9
target_pts = torch.tensor(vertices, device=device)[None].float()
target_normals = torch.tensor(normals, device=device)[None].float()
2023-05-26 12:59:53 +00:00
mesh = None # no GT mesh

        if not torch.is_tensor(center):
            center = torch.from_numpy(center)
        if not torch.is_tensor(scale):
            scale = torch.from_numpy(np.array([scale]))

        data = {
            "target_points": target_pts,
            "target_normals": target_normals,  # normals are never used
            "gt_mesh": mesh,
        }
    else:
        raise NotImplementedError
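
    # target_pts now lie in [-0.9, 0.9]^3: centered, divided by the maximum
    # absolute extent, then shrunk by 0.9 so that the downstream x / 2 + 0.5
    # remapping lands strictly inside (0, 1), which the DPSR grid requires.
    # `center` and `scale` are kept to undo this normalization at export time.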

    # save the input point cloud
    if "target_points" in data:
        outdir_pcl = os.path.join(cfg["train"]["out_dir"], "target_pcl.ply")
        if "target_normals" in data:
            export_pointcloud(outdir_pcl, data["target_points"], data["target_normals"])
        else:
            export_pointcloud(outdir_pcl, data["target_points"])

    # save oracle PSR mesh (mesh from our PSR using GT points + normals)
    if data.get("gt_mesh") is not None:
        gt_verts, gt_faces = data["gt_mesh"].get_mesh_verts_faces(0)
        pts_gt, norms_gt = sample_points_from_meshes(data["gt_mesh"], num_samples=500000, return_normals=True)
        pts_gt = (pts_gt + 1) / 2  # [-1, 1] -> [0, 1]
        from src.dpsr import DPSR

        dpsr_tmp = DPSR(
            res=(cfg["model"]["grid_res"], cfg["model"]["grid_res"], cfg["model"]["grid_res"]),
            sig=cfg["model"]["psr_sigma"],
        ).to(device)
        target = dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device)
        target = torch.tanh(target)
        s = target.shape[-1]  # size of psr_grid
        psr_grid_numpy = target.squeeze().detach().cpu().numpy()
        verts, faces, _, _ = measure.marching_cubes(psr_grid_numpy)
        verts = verts / s * 2.0 - 1  # rescale to [-1, 1]
        mesh = o3d.geometry.TriangleMesh()
        mesh.vertices = o3d.utility.Vector3dVector(verts)
        mesh.triangles = o3d.utility.Vector3iVector(faces)
        outdir_mesh = os.path.join(cfg["train"]["out_dir"], "oracle_mesh.ply")
        o3d.io.write_triangle_mesh(outdir_mesh, mesh)
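        # This oracle mesh is what DPSR reconstructs when given ground-truth
        # points *and* normals; since the optimization below never uses the
        # target normals, it serves as a rough upper bound on mesh quality.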

    # initialize the source point cloud given an input mesh
    if "input_mesh" in cfg["train"] and os.path.isfile(cfg["train"]["input_mesh"]):
        if cfg["train"]["input_mesh"].split("/")[-2] == "mesh":
            mesh_tmp = trimesh.load_mesh(cfg["train"]["input_mesh"])
            verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
            faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
            mesh = Meshes(verts=verts, faces=faces)
            points, normals = sample_points_from_meshes(
                mesh, num_samples=cfg["data"]["num_points"], return_normals=True,
            )
            # mesh is saved in the original scale of the gt
            points -= center.float().to(device)
            points /= scale.float().to(device)
            points *= 0.9
            # make sure the points are within the range of [0, 1)
            points = points / 2.0 + 0.5
        else:
            # directly initialize from a point cloud
            pcd = o3d.io.read_point_cloud(cfg["train"]["input_mesh"])
            points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
            normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
            points -= center.float().to(device)
            points /= scale.float().to(device)
            points *= 0.9
            points = points / 2.0 + 0.5
    else:
        # initialize our source point cloud from a sphere
        sphere_radius = cfg["model"]["sphere_radius"]
        sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256, 256])
        points, idx = sphere_mesh.sample(cfg["data"]["num_points"], return_index=True)
        points += 0.5  # make sure the points are within the range of [0, 1)
        normals = sphere_mesh.face_normals[idx]
        points = torch.from_numpy(points).unsqueeze(0).to(device)
        normals = torch.from_numpy(normals).unsqueeze(0).to(device)

    points = torch.log(points / (1 - points))  # inverse sigmoid (logit)
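    # Optimizing in logit space: the Trainer is assumed to map the variable back
    # through a sigmoid, which keeps every coordinate strictly inside (0, 1)
    # regardless of step size (e.g. p = 0.5 maps to logit 0.0, and
    # sigmoid(logit(p)) recovers p exactly for any p in (0, 1)).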
2021-11-08 10:09:50 +00:00
inputs = torch.cat([points, normals], axis=-1).float()
inputs.requires_grad = True
2023-05-26 12:59:53 +00:00
model = None # no network
2021-11-08 10:09:50 +00:00

    # initialize optimizer
    cfg["train"]["schedule"]["pcl"]["initial"] = cfg["train"]["lr_pcl"]
    print("Initial learning rate:", cfg["train"]["schedule"]["pcl"]["initial"])
    if "schedule" in cfg["train"]:
        lr_schedules = get_learning_rate_schedules(cfg["train"]["schedule"])
    else:
        lr_schedules = None

    optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules)

    try:
        # resume from an existing checkpoint
        state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
        if "pcl" in state_dict and state_dict["pcl"] is not None:
            inputs = state_dict["pcl"].to(device)
            inputs.requires_grad = True
        optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get("epoch"), schedule=lr_schedules)
        out = "Load model from epoch %d" % state_dict.get("epoch", 0)
        print(out)
        logger.info(out)
    except Exception:
        state_dict = dict()
    start_epoch = state_dict.get("epoch", -1)

    trainer = Trainer(cfg, optimizer, device=device)
    runtime = {}
    runtime["all"] = AverageMeter()

    # training loop
    for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
        # schedule the learning rate
        if epoch > 0 and lr_schedules is not None:
            if epoch % lr_schedules[0].interval == 0:
                adjust_learning_rate(lr_schedules, optimizer, epoch)
                if len(lr_schedules) > 1:
                    print(
                        "[epoch {}] net_lr: {}, pcl_lr: {}".format(
                            epoch,
                            lr_schedules[0].get_learning_rate(epoch),
                            lr_schedules[1].get_learning_rate(epoch),
                        )
                    )
                else:
                    print(f"[epoch {epoch}] adjust pcl_lr to: {lr_schedules[0].get_learning_rate(epoch)}")

        start = time.time()
        loss, loss_each = trainer.train_step(data, inputs, model, epoch)
        runtime["all"].update(time.time() - start)

        if epoch % cfg["train"]["print_every"] == 0:
            log_text = "[Epoch %02d] loss=%.5f" % (epoch, loss)
            if loss_each is not None:
                for k, l in loss_each.items():
                    if l.item() != 0.0:
                        log_text += f" loss_{k}={l.item():.5f}"
            log_text += " time={:.3f} / {:.3f}".format(runtime["all"].val, runtime["all"].sum)
            logger.info(log_text)
            print(log_text)

        # visualize point clouds and meshes
        if epoch % cfg["train"]["visualize_every"] == 0 and vis is not None:
            trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)

        # save outputs
        if epoch % cfg["train"]["save_every"] == 0:
            trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy() * (1 / 0.9))
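            # scale * (1 / 0.9) undoes the extra 0.9 shrink from normalization,
            # so the exported geometry comes back in the input's original scale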

        # save checkpoints
        if epoch > 0 and epoch % cfg["train"]["checkpoint_every"] == 0:
            state = {"epoch": epoch}
            if isinstance(inputs, torch.Tensor):
                state["pcl"] = inputs.detach().cpu()
            torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d.pt" % epoch))
            print("Save new model at epoch %d" % epoch)
            logger.info("Save new model at epoch %d" % epoch)
            torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))

        # resample and gradually add new points to the source pcl
        if (
            epoch > 0
            and cfg["train"]["resample_every"] != 0
            and epoch % cfg["train"]["resample_every"] == 0
            and epoch < cfg["train"]["total_epochs"]
        ):
            inputs = trainer.point_resampling(inputs)
            optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules)
            trainer = Trainer(cfg, optimizer, device=device)
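            # point_resampling replaces `inputs` with a new tensor, so the
            # optimizer and the Trainer holding it are rebuilt around it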

    # visualize the Open3D outputs
    if cfg["train"]["o3d_show"]:
        out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video.mp4")
        if os.path.isfile(out_video_dir):
            os.system(f"rm {out_video_dir}")
        os.system(
            "ffmpeg -framerate 30 "
            "-start_number 0 "
            "-i {}/vis/o3d/%04d.jpg "
            "-pix_fmt yuv420p "
            "-crf 17 {}".format(cfg["train"]["out_dir"], out_video_dir)
        )

        out_video_dir = os.path.join(cfg["train"]["out_dir"], "vis/o3d/video_pcd.mp4")
        if os.path.isfile(out_video_dir):
            os.system(f"rm {out_video_dir}")
        os.system(
            "ffmpeg -framerate 30 "
            "-start_number 0 "
            "-i {}/vis/o3d/%04d_pcd.jpg "
            "-pix_fmt yuv420p "
            "-crf 17 {}".format(cfg["train"]["out_dir"], out_video_dir)
        )
        print("Video saved.")


if __name__ == "__main__":
    main()