2023-05-26 12:59:53 +00:00
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
import shutil
|
2021-11-08 10:09:50 +00:00
|
|
|
from collections import defaultdict
|
2023-05-26 12:59:53 +00:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
import torch
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
2021-11-08 10:09:50 +00:00
|
|
|
from src import config
|
|
|
|
from src.dpsr import DPSR
|
|
|
|
from src.model import Encode2Points
|
2023-05-26 12:59:53 +00:00
|
|
|
from src.utils import (
|
|
|
|
export_mesh,
|
|
|
|
export_pointcloud,
|
|
|
|
is_url,
|
|
|
|
load_config,
|
|
|
|
load_model_manual,
|
|
|
|
load_url,
|
|
|
|
mc_from_psr,
|
|
|
|
scale2onet,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Limit printed numpy floats to four decimal places so console logs stay readable.
_NUMPY_PRINT_PRECISION = 4
np.set_printoptions(precision=_NUMPY_PRINT_PRECISION)
|
2021-11-08 10:09:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _parse_args():
    """Build the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Generate meshes / point clouds from a trained Encode2Points model."
    )
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--iter", type=int, metavar="N", help="the training iteration to be evaluated.")
    return parser.parse_args()


def _load_state_dict(cfg, args, out_dir, device):
    """Locate and load the checkpoint to evaluate.

    Priority: ``cfg["test"]["model_file"]`` if it is a URL, then a non-zero
    ``cfg["generation"]["iter"]``, then ``--iter``, then ``model_best.pt``.

    Returns:
        ``(state_dict, dir_suffix)`` where ``dir_suffix`` must be appended to the
        generation directory ("" unless the config-driven iteration is used).
    """
    model_file = cfg["test"]["model_file"]
    if is_url(model_file):
        return load_url(model_file), ""

    cfg_iter = cfg["generation"].get("iter", 0)
    # A missing, null or zero iteration all mean "not requested".
    if cfg_iter:
        path = os.path.join(out_dir, "model", "%04d.pt" % cfg_iter)
        # map_location lets GPU-saved checkpoints load on CPU (e.g. with --no_cuda).
        return torch.load(path, map_location=device), "_%04d" % cfg_iter

    if args.iter is not None:
        # NOTE(review): unlike the config-driven branch, --iter does not suffix the
        # output directory, so its results overwrite the default run's — confirm intent.
        path = os.path.join(out_dir, "model", "%04d.pt" % args.iter)
        return torch.load(path, map_location=device), ""

    return torch.load(os.path.join(out_dir, "model_best.pt"), map_location=device), ""


def _output_dirs(generation_dir, category_id, category_name):
    """Return ``(mesh_dir, in_dir, pointcloud_dir, vis_dir)`` for one sample.

    When the category is known, mesh/input/pointcloud outputs get a per-category
    sub-folder, and the vis folder is named "<id>_<first name>" (just "<id>" when
    the category name is unknown).
    """
    mesh_dir = os.path.join(generation_dir, "meshes")
    in_dir = os.path.join(generation_dir, "input")
    pointcloud_dir = os.path.join(generation_dir, "pointcloud")
    vis_dir = os.path.join(generation_dir, "vis")

    if category_id != "n/a":
        mesh_dir = os.path.join(mesh_dir, str(category_id))
        pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
        in_dir = os.path.join(in_dir, str(category_id))

        folder_name = str(category_id)
        if category_name != "n/a":
            folder_name = folder_name + "_" + category_name.split(",")[0]
        vis_dir = os.path.join(vis_dir, folder_name)

    return mesh_dir, in_dir, pointcloud_dir, vis_dir


def _save_timings(time_dicts, out_time_file, out_time_file_class):
    """Pickle per-sample and per-class timing tables and print the class summary."""
    time_df = pd.DataFrame(time_dicts)
    time_df.set_index(["idx"], inplace=True)
    time_df.to_pickle(out_time_file)

    # numeric_only=True: the frame also holds string columns ("class id",
    # "modelname"), which pandas >= 2.0 refuses to average.
    time_df_class = time_df.groupby(by=["class name"]).mean(numeric_only=True)
    time_df_class.loc["mean"] = time_df_class.mean()
    time_df_class.to_pickle(out_time_file_class)

    print("Timings [s]:")
    print(time_df_class)


def main():
    """Generate outputs for every test sample and record generation timings.

    Loads the config given on the command line (merged over
    ``configs/default.yaml``) and restores an ``Encode2Points`` checkpoint. It then
    runs the configured generator on the test split, writing meshes, point clouds
    and optional input copies below ``<train.out_dir>/<generation.generation_dir>``.
    The first ``vis_n_outputs`` samples of each category are also copied into a
    ``vis`` folder. Timing statistics are pickled into the generation directory.

    Exits with status 1 if the checkpoint cannot be loaded.
    """
    args = _parse_args()
    cfg = load_config(args.config, "configs/default.yaml")

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Apply the --seed flag (previously parsed but ignored).
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    vis_n_outputs = cfg["generation"]["vis_n_outputs"]
    if vis_n_outputs is None:
        vis_n_outputs = -1  # negative: never copy samples into the vis folder

    out_dir = cfg["train"]["out_dir"]
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    generation_dir = os.path.join(out_dir, cfg["generation"]["generation_dir"])

    # PyTorch >= 1.0 is required. Compare the major component explicitly; an
    # assert would be stripped under `python -O`.
    if int(torch.__version__.split(".")[0]) < 1:
        raise RuntimeError("PyTorch >= 1.0.0 is required, found %s" % torch.__version__)

    dataset = config.get_dataset("test", cfg, return_idx=True)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

    model = Encode2Points(cfg).to(device)

    try:
        state_dict, dir_suffix = _load_state_dict(cfg, args, out_dir, device)
        load_model_manual(state_dict["state_dict"], model)
    except Exception as err:
        # Report the cause and exit non-zero instead of silently exiting with 0.
        print("Model loading error. Exiting.")
        print("  %s: %s" % (type(err).__name__, err))
        raise SystemExit(1) from err
    generation_dir += dir_suffix

    # Timing files live in the (possibly iteration-suffixed) directory that the
    # per-sample outputs below actually create.
    out_time_file = os.path.join(generation_dir, "time_generation_full.pkl")
    out_time_file_class = os.path.join(generation_dir, "time_generation.pkl")

    generator = config.get_generator(model, cfg, device=device)

    generate_mesh = cfg["generation"]["generate_mesh"]
    generate_pointcloud = cfg["generation"]["generate_pointcloud"]

    time_dicts = []
    model.eval()
    psr_res = cfg["generation"]["psr_resolution"]
    dpsr = DPSR(res=(psr_res, psr_res, psr_res), sig=cfg["generation"]["psr_sigma"]).to(device)

    # Number of samples already visualised per category.
    model_counter = defaultdict(int)

    print("Generating...")
    for data in tqdm(test_loader):
        idx = data["idx"].item()

        try:
            model_dict = dataset.get_model_dict(idx)
        except AttributeError:
            model_dict = {"model": str(idx), "category": "n/a"}
        modelname = model_dict["model"]
        category_id = model_dict["category"]

        try:
            category_name = dataset.metadata[category_id].get("name", "n/a")
        except (AttributeError, KeyError):
            category_name = "n/a"

        mesh_dir, in_dir, pointcloud_dir, generation_vis_dir = _output_dirs(
            generation_dir, category_id, category_name
        )
        if vis_n_outputs >= 0:
            os.makedirs(generation_vis_dir, exist_ok=True)
        if generate_mesh:
            os.makedirs(mesh_dir, exist_ok=True)
        if generate_pointcloud:
            os.makedirs(pointcloud_dir, exist_ok=True)
        os.makedirs(in_dir, exist_ok=True)

        time_dict = {
            "idx": idx,
            "class id": category_id,
            "class name": category_name,
            "modelname": modelname,
        }
        time_dicts.append(time_dict)

        out_file_dict = {}
        # The generator returns both the mesh and the oriented point cloud, so run
        # it whenever either output is requested (point-cloud-only used to crash).
        if generate_mesh or generate_pointcloud:
            v, f, points, normals, stats_dict = generator.generate_mesh(data)
            time_dict.update(stats_dict)

            if generate_mesh:
                mesh_out_file = os.path.join(mesh_dir, "%s.off" % modelname)
                export_mesh(mesh_out_file, scale2onet(v), f)
                out_file_dict["mesh"] = mesh_out_file

            if generate_pointcloud:
                pointcloud_out_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
                export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
                out_file_dict["pointcloud"] = pointcloud_out_file

        if cfg["generation"]["copy_input"]:
            inputs_path = os.path.join(in_dir, "%s.ply" % modelname)
            p = data.get("inputs").to(device)
            export_pointcloud(inputs_path, scale2onet(p))
            out_file_dict["in"] = inputs_path

        # Copy the first vis_n_outputs samples of each category into the vis folder.
        c_it = model_counter[category_id]
        if c_it < vis_n_outputs:
            for k, filepath in out_file_dict.items():
                ext = os.path.splitext(filepath)[1]
                out_file = os.path.join(generation_vis_dir, "%02d_%s%s" % (c_it, k, ext))
                shutil.copyfile(filepath, out_file)

            # Also reconstruct an "oracle" mesh from the ground-truth points/normals.
            if cfg["generation"]["exp_oracle"]:
                points_gt = data.get("gt_points").to(device)
                normals_gt = data.get("gt_points.normals").to(device)
                psr_gt = dpsr(points_gt, normals_gt)
                oracle_v, oracle_f, _ = mc_from_psr(psr_gt, zero_level=cfg["data"]["zero_level"])
                out_file = os.path.join(
                    generation_vis_dir, "%02d_%s%s" % (c_it, "mesh_oracle", ".off")
                )
                export_mesh(out_file, scale2onet(oracle_v), oracle_f)

        model_counter[category_id] += 1

    if not time_dicts:
        # An empty test split would otherwise crash on set_index("idx").
        print("No samples were generated; skipping timing statistics.")
        return

    _save_timings(time_dicts, out_time_file, out_time_file_class)


if __name__ == "__main__":
    main()
|