Shape-as-Point/eval_meshes.py

146 lines
4.9 KiB
Python
Raw Normal View History

2023-05-26 12:59:53 +00:00
import argparse
import os
import numpy as np
import pandas as pd
2021-11-08 10:09:50 +00:00
import torch
import trimesh
from tqdm import tqdm
2023-05-26 12:59:53 +00:00
from src.data import IndexField, PointCloudField, Shapes3dDataset
from src.eval import MeshEvaluator
from src.utils import load_config, load_pointcloud
# Any numpy array printed by this script is shown with at most 4 decimals.
_NUMPY_PRINT_PRECISION = 4
np.set_printoptions(precision=_NUMPY_PRINT_PRECISION)
def _parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    ``--no_cuda`` and ``--seed`` are still accepted so existing command lines keep
    working, but nothing in this script reads them.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate generated meshes / point clouds against ground-truth point clouds."
    )
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--iter", type=int, metavar="S", help="the training iteration to be evaluated.")
    return parser.parse_args()


def _resolve_generation_dir(cfg, cli_iter):
    """Return the directory holding the generated outputs to evaluate.

    The base is ``<train.out_dir>/<generation.generation_dir>``. A non-zero
    ``generation.iter`` from the config takes precedence over the ``--iter`` CLI
    value; either one appends a ``_%04d`` suffix. ``0`` / ``None`` / missing in the
    config all mean "not set".
    """
    generation_dir = os.path.join(cfg["train"]["out_dir"], cfg["generation"]["generation_dir"])
    cfg_iter = cfg["generation"].get("iter", 0)
    if cfg_iter:
        generation_dir += "_%04d" % cfg_iter
    elif cli_iter is not None:
        generation_dir += "_%04d" % cli_iter
    return generation_dir


def _check_torch_version():
    """Raise ``RuntimeError`` unless the installed PyTorch major version is >= 1.

    Parses the leading component of ``torch.__version__``. The previous code indexed
    ``[-3]``, which picks the wrong field for versions with 4+ components (e.g.
    ``"2.1.0.dev20230101"``). It also used ``assert``, which ``python -O`` strips.
    """
    major = int(str(torch.__version__).split(".")[0])
    if major < 1:
        raise RuntimeError("PyTorch >= 1.0 is required, found %s" % torch.__version__)


def _lookup_model_info(dataset, idx):
    """Return ``(modelname, category_id, category_name)`` for dataset item ``idx``.

    Falls back to ``(str(idx), "n/a", ...)`` when ``dataset.get_model_dict`` raises
    ``AttributeError``. The category name falls back to ``"n/a"`` when
    ``dataset.metadata`` is unavailable or has no entry for the category.
    """
    try:
        model_dict = dataset.get_model_dict(idx)
    except AttributeError:
        model_dict = {"model": str(idx), "category": "n/a"}

    modelname = model_dict["model"]
    category_id = model_dict["category"]

    try:
        category_name = dataset.metadata[category_id].get("name", "n/a")
    except (AttributeError, KeyError):
        # KeyError covers categories absent from metadata -- including the
        # "n/a" fallback above, which previously crashed the run.
        category_name = "n/a"
    return modelname, category_id, category_name


def _evaluate_model(cfg, evaluator, mesh_dir, pointcloud_dir, modelname, pointcloud_tgt, normals_tgt):
    """Compute metrics for one model's generated mesh and/or point cloud.

    ``cfg["test"]["eval_mesh"]`` enables evaluation of ``<mesh_dir>/<modelname>.off``.
    ``cfg["test"]["eval_pointcloud"]`` enables evaluation of
    ``<pointcloud_dir>/<modelname>.ply``.

    Returns a dict whose keys are the evaluator's metric names suffixed with
    ``" (mesh)"`` or ``" (pcl)"``. Missing files are reported with a warning and
    contribute no metrics.
    """
    metrics = {}

    if cfg["test"]["eval_mesh"]:
        mesh_file = os.path.join(mesh_dir, "%s.off" % modelname)
        if os.path.exists(mesh_file):
            mesh = trimesh.load(mesh_file, process=False)
            eval_dict_mesh = evaluator.eval_mesh(mesh, pointcloud_tgt, normals_tgt)
            for k, v in eval_dict_mesh.items():
                metrics[k + " (mesh)"] = v
        else:
            print("Warning: mesh does not exist: %s" % mesh_file)

    if cfg["test"]["eval_pointcloud"]:
        pointcloud_file = os.path.join(pointcloud_dir, "%s.ply" % modelname)
        if os.path.exists(pointcloud_file):
            pointcloud = load_pointcloud(pointcloud_file).astype(np.float32)
            eval_dict_pcl = evaluator.eval_pointcloud(pointcloud, pointcloud_tgt)
            for k, v in eval_dict_pcl.items():
                metrics[k + " (pcl)"] = v
        else:
            print("Warning: pointcloud does not exist: %s" % pointcloud_file)

    return metrics


def _save_results(eval_dicts, out_file, out_file_class):
    """Write per-model results to ``out_file`` and per-class means to ``out_file_class``.

    The per-class CSV gets an extra ``"mean"`` row holding the mean over classes.
    Nothing is written when ``eval_dicts`` is empty, because there would be no
    ``idx`` column to index on.
    """
    if not eval_dicts:
        print("No models were evaluated; nothing to save.")
        return

    # Create pandas dataframe and save
    eval_df = pd.DataFrame(eval_dicts)
    eval_df.set_index(["idx"], inplace=True)
    eval_df.to_pickle(out_file)

    # Create CSV file with main statistics.
    # numeric_only=True: since pandas 2.0, GroupBy.mean() no longer silently drops
    # the string columns ("class id", "modelname") and raises TypeError instead.
    eval_df_class = eval_df.groupby(by=["class name"]).mean(numeric_only=True)
    eval_df_class.loc["mean"] = eval_df_class.mean(numeric_only=True)
    eval_df_class.to_csv(out_file_class)

    # Print results
    print(eval_df_class)


def main():
    """Evaluate generated meshes / point clouds for the configured test split.

    Loads the config given on the command line (with ``configs/default.yaml`` as
    defaults). For every test model it compares the generated outputs found under
    the generation directory with the dataset's ground-truth point cloud and
    normals. It then writes ``eval_meshes_full.pkl`` (one row per model) and
    ``eval_meshes.csv`` (per-class means plus an overall ``mean`` row) into that
    directory.
    """
    args = _parse_args()
    cfg = load_config(args.config, "configs/default.yaml")

    generation_dir = _resolve_generation_dir(cfg, args.iter)
    print("Evaluate meshes under %s" % generation_dir)
    out_file = os.path.join(generation_dir, "eval_meshes_full.pkl")
    out_file_class = os.path.join(generation_dir, "eval_meshes.csv")

    _check_torch_version()

    fields = {
        "pointcloud": PointCloudField(cfg["data"]["pointcloud_file"]),
        "idx": IndexField(),
    }
    print("Test split: ", cfg["data"]["test_split"])
    dataset = Shapes3dDataset(
        cfg["data"]["path"], fields, cfg["data"]["test_split"], categories=cfg["data"]["class"], cfg=cfg,
    )
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

    evaluator = MeshEvaluator(n_points=100000)
    # Base output directories are loop-invariant; per-category subdirs are added below.
    mesh_root = os.path.join(generation_dir, "meshes")
    pointcloud_root = os.path.join(generation_dir, "pointcloud")

    eval_dicts = []
    print("Evaluating meshes...")
    for data in tqdm(test_loader):
        if data is None:
            print("Invalid data.")
            continue

        idx = data["idx"].item()
        modelname, category_id, category_name = _lookup_model_info(dataset, idx)

        mesh_dir, pointcloud_dir = mesh_root, pointcloud_root
        if category_id != "n/a":
            mesh_dir = os.path.join(mesh_root, category_id)
            pointcloud_dir = os.path.join(pointcloud_root, category_id)

        # batch_size=1: drop the leading batch axis before handing numpy arrays
        # to the evaluator.
        pointcloud_tgt = data["pointcloud"].squeeze(0).numpy()
        normals_tgt = data["pointcloud.normals"].squeeze(0).numpy()

        eval_dict = {
            "idx": idx,
            "class id": category_id,
            "class name": category_name,
            "modelname": modelname,
        }
        eval_dict.update(
            _evaluate_model(cfg, evaluator, mesh_dir, pointcloud_dir, modelname, pointcloud_tgt, normals_tgt)
        )
        eval_dicts.append(eval_dict)

    _save_results(eval_dicts, out_file, out_file_class)
def _run_cli():
    """Run the evaluation script; used only when this file is executed directly."""
    main()


if __name__ == "__main__":
    _run_cli()