# Shape-as-Point/scripts/process_shapenet.py
# Snapshot metadata: 2023-05-26 14:59:53 +02:00 — 117 lines, 3.1 KiB, Python
import multiprocessing
import os
import time
import numpy as np
import torch
from tqdm import tqdm
from src.dpsr import DPSR
# ---------------------------------------------------------------------------
# Configuration (edit in place; this script takes no CLI arguments).
# ---------------------------------------------------------------------------
data_path = "data/ShapeNet" # path for ShapeNet from ONet
base = "data" # output base directory
dataset_name = "shapenet_psr"  # output subfolder: <base>/<dataset_name>/<class>/<obj>
multiprocess = True  # process objects with a worker pool instead of serially
njobs = 8  # number of worker processes when multiprocess is True
save_pointcloud = True  # write pointcloud.npz (raw points/normals) per object
save_psr_field = True  # write psr.npz (DPSR indicator grid) per object
resolution = 128  # PSR grid resolution per axis
zero_level = 0.0  # NOTE(review): unused in the visible code — TODO confirm
num_points = 100000  # NOTE(review): unused in the visible code — TODO confirm
padding = 1.2  # input points live in [-padding/2, padding/2); used to map to [0, 1)
# Differentiable Poisson Surface Reconstruction operator; sig=0 disables the
# Gaussian smoothing of the point rasterization.
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
def process_one(obj):
    """Preprocess one ShapeNet object: copy its point cloud and compute its PSR grid.

    Args:
        obj: relative identifier "<class_id>/<object_id>" pointing at
            <data_path>/<class_id>/<object_id>/pointcloud.npz.

    Side effects:
        Writes pointcloud.npz and/or psr.npz (per the save_* flags) under
        <base>/<dataset_name>/<class_id>/<object_id>/.
    """
    obj_name = obj.split("/")[-1]
    c = obj.split("/")[-2]

    # Output directory for the current object.
    out_path_cur_obj = os.path.join(base, dataset_name, c, obj_name)
    os.makedirs(out_path_cur_obj, exist_ok=True)

    gt_path = os.path.join(data_path, c, obj_name, "pointcloud.npz")
    data = np.load(gt_path)
    normals = data["normals"]
    # Normalize the points to [0, 1) for DPSR.
    # To scale back during inference: p = (p - 0.5) * padding
    points = data["points"] / padding + 0.5

    if save_pointcloud:
        # Deliberately saves the *raw* (unnormalized) points/normals; only the
        # PSR computation below consumes the normalized copy.
        outdir = os.path.join(out_path_cur_obj, "pointcloud.npz")
        np.savez(outdir, points=data["points"], normals=data["normals"])

    if save_psr_field:
        # DPSR expects float32 batched tensors; store the grid as float16
        # to halve the on-disk footprint.
        psr_gt = (
            dpsr(
                torch.from_numpy(points.astype(np.float32))[None],
                torch.from_numpy(normals.astype(np.float32))[None],
            )
            .squeeze()
            .cpu()
            .numpy()
            .astype(np.float16)
        )
        outdir = os.path.join(out_path_cur_obj, "psr.npz")
        np.savez(outdir, psr=psr_gt)
def main(c):
    """Process all splits (train/val/test) of ShapeNet class *c*.

    Reads <data_path>/<c>/<split>.lst for each split and dispatches every
    listed object to process_one, optionally through a multiprocessing pool.
    """
    for split in ["train", "val", "test"]:
        # BUG FIX: the banner below referenced `split` before the loop bound
        # it (NameError); it now lives inside the loop.
        print("---------------------------------------")
        print(f"Processing {c} {split}")
        print("---------------------------------------")

        fname = os.path.join(data_path, c, split + ".lst")
        with open(fname) as f:
            obj_list = f.read().splitlines()
        obj_list = [c + "/" + s for s in obj_list]

        if multiprocess:
            # multiprocessing.set_start_method('spawn', force=True)
            pool = multiprocessing.Pool(njobs)
            try:
                # imap_unordered lets tqdm advance as each worker finishes.
                for _ in tqdm(pool.imap_unordered(process_one, obj_list), total=len(obj_list)):
                    pass
            except KeyboardInterrupt:
                # Allow ^C to interrupt from any thread; terminate() so the
                # worker processes are not left running after exit.
                pool.terminate()
                pool.join()
                exit()
            pool.close()
            # BUG FIX: join the pool so all workers finish (and are reaped)
            # before the next split creates a new pool.
            pool.join()
        else:
            for obj in tqdm(obj_list):
                process_one(obj)
        print(f"Done Processing {c} {split}!")
if __name__ == "__main__":
    # ShapeNet synset IDs (13 categories) to preprocess, in order.
    classes = [
        "02691156", "02828884", "02933112", "02958343", "03211117",
        "03001627", "03636649", "03691459", "04090263", "04256520",
        "04379243", "04401088", "04530566",
    ]
    t_start = time.time()
    for c in classes:
        main(c)
    print("Total processing time: ", time.time() - t_start)