# Shape-as-Point/scripts/process_shapenet.py
# (copied from a repository file view: 117 lines, 3.1 KiB, Python;
#  last touched 2023-05-26 12:59:53 +00:00)
# Standard library
import multiprocessing
import os
import time

# Third-party
import numpy as np
import torch
from tqdm import tqdm

# Project-local: differentiable Poisson surface reconstruction
from src.dpsr import DPSR
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
data_path = "data/ShapeNet"  # path for ShapeNet from ONet
base = "data"  # output base directory
dataset_name = "shapenet_psr"

multiprocess = True  # fan processing out over a worker pool
njobs = 8  # number of worker processes
save_pointcloud = True  # write per-object pointcloud.npz
save_psr_field = True  # write per-object psr.npz
resolution = 128  # PSR grid resolution per axis
zero_level = 0.0  # iso-level of the PSR field (unused here; kept for reference)
num_points = 100000  # NOTE(review): currently unused in this script — confirm
padding = 1.2  # points are mapped by p/padding + 0.5 into [0, 1)

# Shared DPSR operator; sig=0 presumably disables Gaussian smoothing of the
# field — confirm against src.dpsr.DPSR.
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)
def process_one(obj):
    """Process one ShapeNet object: save its point cloud and/or its PSR field.

    Args:
        obj: identifier of the form "<class_id>/<object_name>".

    Side effects:
        Writes <base>/<dataset_name>/<class_id>/<object_name>/pointcloud.npz
        and/or .../psr.npz, creating directories as needed.
    """
    obj_name = obj.split("/")[-1]
    c = obj.split("/")[-2]

    # create the output directory for the current object
    out_path_cur = os.path.join(base, dataset_name, c)
    out_path_cur_obj = os.path.join(out_path_cur, obj_name)
    os.makedirs(out_path_cur_obj, exist_ok=True)

    # ground-truth point cloud from the ONet-style ShapeNet layout
    gt_path = os.path.join(data_path, c, obj_name, "pointcloud.npz")
    data = np.load(gt_path)
    points = data["points"]
    normals = data["normals"]

    # normalize the points to [0, 1)
    points = points / padding + 0.5
    # to scale back during inference: p = (p - 0.5) * padding

    if save_pointcloud:
        # Deliberately save the ORIGINAL (un-normalized) points/normals,
        # not the normalized copy computed above.
        outdir = os.path.join(out_path_cur_obj, "pointcloud.npz")
        np.savez(outdir, points=data["points"], normals=data["normals"])

    if save_psr_field:
        # DPSR expects a leading batch dimension; squeeze it back out and
        # store the field as float16 to save disk space.
        psr_gt = (
            dpsr(
                torch.from_numpy(points.astype(np.float32))[None],
                torch.from_numpy(normals.astype(np.float32))[None],
            )
            .squeeze()
            .cpu()
            .numpy()
            .astype(np.float16)
        )
        outdir = os.path.join(out_path_cur_obj, "psr.npz")
        np.savez(outdir, psr=psr_gt)
def main(c):
    """Process every object of ShapeNet class `c` in the train/val/test splits.

    Reads <data_path>/<c>/<split>.lst for each split and runs `process_one`
    on each listed object, optionally via a multiprocessing pool.
    """
    for split in ["train", "val", "test"]:
        # BUGFIX: these banner prints used `split` before the loop bound it
        # (NameError); they belong inside the per-split loop.
        print("---------------------------------------")
        print(f"Processing {c} {split}")
        print("---------------------------------------")

        fname = os.path.join(data_path, c, split + ".lst")
        with open(fname) as f:
            obj_list = f.read().splitlines()

        obj_list = [c + "/" + s for s in obj_list]

        if multiprocess:
            pool = multiprocessing.Pool(njobs)
            try:
                # imap_unordered lets tqdm tick as each object finishes.
                for _ in tqdm(pool.imap_unordered(process_one, obj_list), total=len(obj_list)):
                    pass
            except KeyboardInterrupt:
                # Allow ^C to interrupt from any thread.
                raise SystemExit
            pool.close()
            # BUGFIX: wait for workers to drain before starting the next
            # split's pool (close() alone does not block).
            pool.join()
        else:
            for obj in tqdm(obj_list):
                process_one(obj)

        print(f"Done Processing {c} {split}!")
if __name__ == "__main__":
    # ShapeNet class ids to process — presumably the 13-category ONet subset;
    # confirm against the dataset layout.
    classes = [
        "02691156",
        "02828884",
        "02933112",
        "02958343",
        "03211117",
        "03001627",
        "03636649",
        "03691459",
        "04090263",
        "04256520",
        "04379243",
        "04401088",
        "04530566",
    ]

    t_start = time.time()
    for c in classes:
        main(c)
    t_end = time.time()
    print("Total processing time: ", t_end - t_start)