2023-05-26 12:59:53 +00:00
|
|
|
import argparse
|
|
|
|
import glob
|
|
|
|
import os
|
|
|
|
import shutil
|
|
|
|
import time
|
2021-11-08 10:09:50 +00:00
|
|
|
|
2023-05-26 12:59:53 +00:00
|
|
|
import numpy as np
|
2021-11-08 10:09:50 +00:00
|
|
|
import open3d as o3d
|
2023-05-26 12:59:53 +00:00
|
|
|
import torch
|
|
|
|
import trimesh
|
2021-11-08 10:09:50 +00:00
|
|
|
from plyfile import PlyData
|
|
|
|
from pytorch3d.io import load_objs_as_meshes
|
2023-05-26 12:59:53 +00:00
|
|
|
from pytorch3d.ops import sample_points_from_meshes
|
2021-11-08 10:09:50 +00:00
|
|
|
from pytorch3d.structures import Meshes
|
2023-05-26 12:59:53 +00:00
|
|
|
from skimage import measure
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
|
|
from src.optimization import Trainer
|
|
|
|
from src.utils import (
|
|
|
|
AverageMeter,
|
|
|
|
adjust_learning_rate,
|
|
|
|
export_pointcloud,
|
|
|
|
get_learning_rate_schedules,
|
|
|
|
initialize_logger,
|
|
|
|
load_config,
|
|
|
|
update_config,
|
|
|
|
update_optimizer,
|
|
|
|
)
|
|
|
|
|
|
|
|
np.set_printoptions(precision=4)

# The target shape is normalized into [-_NORM_MARGIN, _NORM_MARGIN]^3 and later mapped to
# [0, 1) for DPSR. Keeping the margin below 1 is important: DPSR coordinates must not reach 1.
_NORM_MARGIN = 0.9


def _is_due(epoch, every):
    """Return True when a periodic action with period ``every`` should run at ``epoch``.

    A non-positive period disables the action instead of raising ZeroDivisionError.
    """
    return every > 0 and epoch % every == 0


def _inverse_sigmoid(x, eps=1e-6):
    """Return logit(x) = log(x / (1 - x)) after clamping ``x`` to [eps, 1 - eps].

    The clamp keeps points lying on (or outside) the unit-cube boundary from turning into
    +/-inf or NaN; values already strictly inside (eps, 1 - eps) are unchanged.
    """
    x = x.clamp(eps, 1.0 - eps)
    return torch.log(x / (1 - x))


def _load_target(cfg, device):
    """Load the target shape and normalize it into [-_NORM_MARGIN, _NORM_MARGIN]^3.

    Args:
        cfg: experiment config; reads cfg["data"]["data_type" | "object_id" | "data_path"].
        device: torch device for the loaded tensors.

    Returns:
        (data, center, scale): ``data`` has keys "target_points", "target_normals" and
        "gt_mesh" (a PyTorch3D mesh for .obj input, None for .ply input). ``center`` and
        ``scale`` are tensors such that normalized = (orig - center) / scale * _NORM_MARGIN.

    Raises:
        NotImplementedError: for any data_type other than "point".
        ValueError: if the target file is neither .obj nor .ply.
    """
    data_type = cfg["data"]["data_type"]
    if data_type != "point":
        raise NotImplementedError(f"Unsupported data_type: {data_type!r}")

    if cfg["data"]["object_id"] != -1:
        # data_path is a glob pattern; object_id selects one match in sorted order.
        data_paths = sorted(glob.glob(cfg["data"]["data_path"]))
        data_path = data_paths[cfg["data"]["object_id"]]
        print("Loaded %d/%d object" % (cfg["data"]["object_id"] + 1, len(data_paths)))
    else:
        data_path = cfg["data"]["data_path"]
        print("Data loaded")

    ext = os.path.splitext(data_path)[1].lstrip(".").lower()
    if ext == "obj":  # have GT mesh
        mesh = load_objs_as_meshes([data_path], device=device)
        verts = mesh.verts_packed()
        center = verts.mean(0)
        mesh.offset_verts_(-center.expand(verts.shape[0], 3))
        # Largest absolute coordinate after centering (0-d tensor).
        scale = (verts - center).abs().max()
        mesh.scale_verts_(1.0 / float(scale))
        mesh.scale_verts_(_NORM_MARGIN)
        target_pts, target_normals = sample_points_from_meshes(mesh, num_samples=200000, return_normals=True)
    elif ext == "ply":  # only have the point cloud
        vertex = PlyData.read(data_path)["vertex"]
        vertices = np.stack([vertex["x"], vertex["y"], vertex["z"]], axis=1)
        normals = np.stack([vertex["nx"], vertex["ny"], vertex["nz"]], axis=1)
        center = vertices.mean(0)
        scale = np.max(np.abs(vertices - center))
        vertices = (vertices - center) / scale * _NORM_MARGIN
        target_pts = torch.tensor(vertices, device=device)[None].float()
        target_normals = torch.tensor(normals, device=device)[None].float()
        mesh = None  # no GT mesh
        center = torch.from_numpy(center)
        scale = torch.from_numpy(np.array([scale]))
    else:
        raise ValueError(f"Unsupported target file extension {ext!r} ({data_path}); expected .obj or .ply")

    data = {
        "target_points": target_pts,
        # exported with the target point cloud; per the original note, not used by the optimization
        "target_normals": target_normals,
        "gt_mesh": mesh,
    }
    return data, center, scale


def _save_oracle_mesh(cfg, gt_mesh, device, out_path):
    """Run DPSR on dense GT points+normals, extract an iso-surface, and write it to ``out_path``."""
    from src.dpsr import DPSR

    pts_gt, norms_gt = sample_points_from_meshes(gt_mesh, num_samples=500000, return_normals=True)
    pts_gt = (pts_gt + 1) / 2  # [-1, 1] -> [0, 1], the range DPSR works in
    res = cfg["model"]["grid_res"]
    dpsr_tmp = DPSR(res=(res, res, res), sig=cfg["model"]["psr_sigma"]).to(device)
    with torch.no_grad():  # oracle only; no gradients needed
        psr_grid = torch.tanh(dpsr_tmp(pts_gt, norms_gt).unsqueeze(1).to(device))
    grid_size = psr_grid.shape[-1]
    verts, faces, _, _ = measure.marching_cubes(psr_grid.squeeze().detach().cpu().numpy())
    verts = verts / grid_size * 2.0 - 1  # [-1, 1]

    oracle = o3d.geometry.TriangleMesh()
    oracle.vertices = o3d.utility.Vector3dVector(verts)
    oracle.triangles = o3d.utility.Vector3iVector(faces)
    o3d.io.write_triangle_mesh(out_path, oracle)


def _init_source_points(cfg, center, scale, device):
    """Build the optimizable source cloud: a (1, N, 6) tensor of [logit(xyz in [0, 1)), normal].

    Initializes from cfg["train"]["input_mesh"] when it names an existing file: a mesh when
    its parent directory is named "mesh", otherwise a point cloud. Both are assumed to be in
    the target's original frame and are normalized with ``center``/``scale``. Otherwise
    points are sampled from a sphere of radius cfg["model"]["sphere_radius"] at 0.5.
    """
    input_mesh = cfg["train"].get("input_mesh")
    if input_mesh and os.path.isfile(input_mesh):
        if os.path.basename(os.path.dirname(input_mesh)) == "mesh":
            mesh_tmp = trimesh.load_mesh(input_mesh)
            verts = torch.from_numpy(mesh_tmp.vertices[None]).float().to(device)
            faces = torch.from_numpy(mesh_tmp.faces[None]).to(device)
            points, normals = sample_points_from_meshes(
                Meshes(verts=verts, faces=faces), num_samples=cfg["data"]["num_points"], return_normals=True,
            )
        else:
            # directly initialize from a point cloud
            pcd = o3d.io.read_point_cloud(input_mesh)
            points = torch.from_numpy(np.array(pcd.points)[None]).float().to(device)
            normals = torch.from_numpy(np.array(pcd.normals)[None]).float().to(device)
        # normalize like the target, then map [-1, 1] -> [0, 1)
        points = (points - center.float().to(device)) / scale.float().to(device) * _NORM_MARGIN
        points = points / 2.0 + 0.5
    else:  # initialize the source point cloud from a sphere
        sphere_radius = cfg["model"]["sphere_radius"]
        if not 0 < sphere_radius <= 0.5:
            raise ValueError(f"model.sphere_radius must be in (0, 0.5] to fit inside the unit cube, got {sphere_radius}")
        sphere_mesh = trimesh.creation.uv_sphere(radius=sphere_radius, count=[256, 256])
        points, idx = sphere_mesh.sample(cfg["data"]["num_points"], return_index=True)
        points += 0.5  # center the sphere in the unit cube
        normals = sphere_mesh.face_normals[idx]
        points = torch.from_numpy(points).unsqueeze(0).to(device)
        normals = torch.from_numpy(normals).unsqueeze(0).to(device)

    # Positions are optimized in logit space so a sigmoid maps them back into (0, 1).
    inputs = torch.cat([_inverse_sigmoid(points), normals], dim=-1).float()
    inputs.requires_grad = True
    return inputs


def _load_checkpoint(ckpt_path, logger):
    """Return the checkpoint dict at ``ckpt_path``, or {} when there is nothing usable to resume."""
    if not os.path.isfile(ckpt_path):
        return {}
    try:
        state_dict = torch.load(ckpt_path, map_location="cpu")
    except Exception as err:  # resuming is best-effort: an unreadable file means a fresh run
        msg = f"Could not read checkpoint {ckpt_path} ({err}); starting from scratch"
        print(msg)
        logger.info(msg)
        return {}
    if not isinstance(state_dict, dict) or state_dict.get("pcl") is None:
        msg = f"Checkpoint {ckpt_path} has no point cloud; starting from scratch"
        print(msg)
        logger.info(msg)
        return {}
    return state_dict


def _make_video(frame_pattern, out_path):
    """Encode numbered frames matching ``frame_pattern`` into ``out_path`` with ffmpeg; True on success."""
    import subprocess

    if os.path.isfile(out_path):
        os.remove(out_path)
    cmd = [
        "ffmpeg", "-framerate", "30", "-start_number", "0", "-i", frame_pattern,
        "-pix_fmt", "yuv420p", "-crf", "17", out_path,
    ]
    try:
        return subprocess.run(cmd, check=False).returncode == 0
    except FileNotFoundError:
        print("ffmpeg not found; skipping video export")
        return False


def main():
    """Parse CLI args, load the target shape, and optimize a source point cloud towards it.

    Unknown CLI arguments are forwarded to ``update_config`` (presumably as config
    overrides). Outputs (config copy, target cloud, oracle mesh, checkpoints, optional
    Open3D videos) are written under cfg["train"]["out_dir"].
    """
    parser = argparse.ArgumentParser(
        description="Optimize a point cloud to fit a target shape via differentiable Poisson surface reconstruction",
    )
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1457, metavar="S", help="random seed")
    args, unknown = parser.parse_known_args()

    cfg = load_config(args.config, "configs/default.yaml")
    cfg = update_config(cfg, unknown)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # PyTorch >= 1.0 is required (explicit check, unlike assert it survives python -O).
    if int(torch.__version__.split(".")[0]) < 1:
        raise RuntimeError(f"PyTorch >= 1.0 is required, found {torch.__version__}")

    # boiler-plate
    if cfg["train"]["timestamp"]:
        cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
    print(cfg["train"]["out_dir"])
    logger = initialize_logger(cfg)
    out_dir = cfg["train"]["out_dir"]
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    os.makedirs(out_dir, exist_ok=True)  # no-op if initialize_logger already created it
    shutil.copyfile(args.config, os.path.join(out_dir, "config.yaml"))

    # tensorboard writer (kept so it can be closed at the end)
    tblogdir = os.path.join(out_dir, "tensorboard_log")
    os.makedirs(tblogdir, exist_ok=True)
    writer = SummaryWriter(log_dir=tblogdir)

    # initialize o3d visualizer
    vis = None
    if cfg["train"]["o3d_show"]:
        vis = o3d.visualization.Visualizer()
        vis.create_window(width=cfg["train"]["o3d_window_size"], height=cfg["train"]["o3d_window_size"])

    # initialize dataset and save the input point cloud
    data, center, scale = _load_target(cfg, device)
    export_pointcloud(os.path.join(out_dir, "target_pcl.ply"), data["target_points"], data["target_normals"])

    # save oracle PSR mesh (mesh from our PSR using GT point+normals)
    if data["gt_mesh"] is not None:
        _save_oracle_mesh(cfg, data["gt_mesh"], device, os.path.join(out_dir, "oracle_mesh.ply"))

    inputs = _init_source_points(cfg, center, scale, device)
    model = None  # no network

    # initialize optimizer; the configured lr_pcl becomes the schedule's initial value
    schedule_cfg = cfg["train"].get("schedule")
    if schedule_cfg:
        schedule_cfg["pcl"]["initial"] = cfg["train"]["lr_pcl"]
        print("Initial learning rate:", schedule_cfg["pcl"]["initial"])
        lr_schedules = get_learning_rate_schedules(schedule_cfg)
    else:
        print("Initial learning rate:", cfg["train"]["lr_pcl"])
        lr_schedules = None
    optimizer = update_optimizer(inputs, cfg, epoch=0, model=model, schedule=lr_schedules)

    # resume from out_dir/model.pt when a usable checkpoint exists
    state_dict = _load_checkpoint(os.path.join(out_dir, "model.pt"), logger)
    if state_dict:
        inputs = state_dict["pcl"].to(device).detach().requires_grad_(True)
        optimizer = update_optimizer(inputs, cfg, epoch=state_dict.get("epoch"), schedule=lr_schedules)
        out = "Load model from epoch %d" % state_dict.get("epoch", 0)
        print(out)
        logger.info(out)
    start_epoch = state_dict.get("epoch", -1)

    trainer = Trainer(cfg, optimizer, device=device)
    runtime = {"all": AverageMeter()}

    # training loop
    for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
        # schedule the learning rate
        if epoch > 0 and lr_schedules is not None and _is_due(epoch, lr_schedules[0].interval):
            adjust_learning_rate(lr_schedules, optimizer, epoch)
            if len(lr_schedules) > 1:
                print(
                    "[epoch {}] net_lr: {}, pcl_lr: {}".format(
                        epoch, lr_schedules[0].get_learning_rate(epoch), lr_schedules[1].get_learning_rate(epoch),
                    ),
                )
            else:
                print(f"[epoch {epoch}] adjust pcl_lr to: {lr_schedules[0].get_learning_rate(epoch)}")

        start = time.perf_counter()
        loss, loss_each = trainer.train_step(data, inputs, model, epoch)
        runtime["all"].update(time.perf_counter() - start)

        if _is_due(epoch, cfg["train"]["print_every"]):
            log_text = "[Epoch %02d] loss=%.5f" % (epoch, loss)
            if loss_each is not None:
                for name, term in loss_each.items():
                    value = term.item()
                    if value != 0.0:  # only report active loss terms
                        log_text += f" loss_{name}={value:.5f}"
            log_text += " time={:.3f} / {:.3f}".format(runtime["all"].val, runtime["all"].sum)
            logger.info(log_text)
            print(log_text)

        # visualize point clouds and meshes
        if vis is not None and _is_due(epoch, cfg["train"]["visualize_every"]):
            trainer.visualize(data, inputs, model, epoch, o3d_vis=vis)

        # save outputs in the target's original frame
        if _is_due(epoch, cfg["train"]["save_every"]):
            trainer.save_mesh_pointclouds(inputs, epoch, center.cpu().numpy(), scale.cpu().numpy() * (1 / _NORM_MARGIN))

        # save checkpoints
        if epoch > 0 and _is_due(epoch, cfg["train"]["checkpoint_every"]):
            state = {"epoch": epoch}
            if isinstance(inputs, torch.Tensor):
                state["pcl"] = inputs.detach().cpu()
            torch.save(state, os.path.join(cfg["train"]["dir_model"], f"{epoch:04d}.pt"))
            msg = "Save new model at epoch %d" % epoch
            print(msg)
            logger.info(msg)
            torch.save(state, os.path.join(out_dir, "model.pt"))

        # resample and gradually add new points to the source pcl (never on the last epoch)
        if (
            epoch > 0
            and _is_due(epoch, cfg["train"]["resample_every"])
            and epoch < cfg["train"]["total_epochs"]
        ):
            inputs = trainer.point_resampling(inputs)
            optimizer = update_optimizer(inputs, cfg, epoch=epoch, model=model, schedule=lr_schedules)
            trainer = Trainer(cfg, optimizer, device=device)

    writer.close()

    # turn the Open3D frames into videos
    if cfg["train"]["o3d_show"]:
        vis.destroy_window()
        vis_dir = os.path.join(out_dir, "vis", "o3d")
        ok_mesh = _make_video(os.path.join(vis_dir, "%04d.jpg"), os.path.join(vis_dir, "video.mp4"))
        ok_pcd = _make_video(os.path.join(vis_dir, "%04d_pcd.jpg"), os.path.join(vis_dir, "video_pcd.mp4"))
        if ok_mesh and ok_pcd:
            print("Video saved.")
        else:
            print("Video export failed (see ffmpeg output above).")


if __name__ == "__main__":
    main()
|