import multiprocessing
import os
import time

import numpy as np
import torch
from tqdm import tqdm

from src.dpsr import DPSR

data_path = "data/ShapeNet"  # path for ShapeNet from ONet
base = "data"  # output base directory
dataset_name = "shapenet_psr"

multiprocess = True
njobs = 8
save_pointcloud = True
save_psr_field = True
resolution = 128
zero_level = 0.0
num_points = 100000
padding = 1.2

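# DPSR solves a differentiable Poisson surface reconstruction on a regular grid
# (see src/dpsr.py); sig is presumably the Gaussian smoothing bandwidth applied
# to the rasterized grid, so sig=0 keeps it unsmoothed.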
dpsr = DPSR(res=(resolution, resolution, resolution), sig=0)


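# Worker executed once per object: loads the ONet point cloud, normalizes it
# into the unit cube, and writes pointcloud.npz and/or psr.npz to the output
# directory.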
def process_one(obj):
    obj_name = obj.split("/")[-1]
    c = obj.split("/")[-2]

    # create the output directory for the current object
    out_path_cur = os.path.join(base, dataset_name, c)
    out_path_cur_obj = os.path.join(out_path_cur, obj_name)
    os.makedirs(out_path_cur_obj, exist_ok=True)

    gt_path = os.path.join(data_path, c, obj_name, "pointcloud.npz")
    data = np.load(gt_path)
    points = data["points"]
    normals = data["normals"]

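    # assuming the ONet clouds are normalized to [-0.5, 0.5]^3, dividing by
    # padding = 1.2 shrinks them to roughly [-0.417, 0.417], so the +0.5 shift
    # leaves a margin inside the unit cube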
    # normalize the points to [0, 1)
    points = points / padding + 0.5
    # to scale back during inference, use:
    #! p = (p - 0.5) * padding

    if save_pointcloud:
        outdir = os.path.join(out_path_cur_obj, "pointcloud.npz")
        # save the original (unnormalized) points and normals, not the
        # normalized copy
        np.savez(outdir, points=data["points"], normals=data["normals"])

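    # rasterize the oriented, normalized points and solve for the ground-truth
    # PSR indicator grid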
    if save_psr_field:
        psr_gt = (
            dpsr(
                torch.from_numpy(points.astype(np.float32))[None],
                torch.from_numpy(normals.astype(np.float32))[None],
            )
            .squeeze()
            .cpu()
            .numpy()
            .astype(np.float16)
        )
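        # [None] adds the batch dimension that squeeze() removes again; float16
        # halves the on-disk size of the resolution^3 grid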
        outdir = os.path.join(out_path_cur_obj, "psr.npz")
        np.savez(outdir, psr=psr_gt)


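# Processes every split of one ShapeNet class; <data_path>/<class>/<split>.lst
# lists the object ids belonging to that split.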
def main(c):
    for split in ["train", "val", "test"]:
        print("---------------------------------------")
        print(f"Processing {c} {split}")
        print("---------------------------------------")

        fname = os.path.join(data_path, c, split + ".lst")
        with open(fname) as f:
            obj_list = f.read().splitlines()

        obj_list = [c + "/" + s for s in obj_list]

        if multiprocess:
            # multiprocessing.set_start_method('spawn', force=True)
            pool = multiprocessing.Pool(njobs)
            try:
                for _ in tqdm(pool.imap_unordered(process_one, obj_list), total=len(obj_list)):
                    pass
            except KeyboardInterrupt:
                # Allow ^C to interrupt from any thread.
                exit()
            pool.close()
            pool.join()
        else:
            for obj in tqdm(obj_list):
                process_one(obj)

        print(f"Done processing {c} {split}!")


if __name__ == "__main__":
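    # the 13 ShapeNet synsets used in the ONet benchmark (airplane, bench,
    # cabinet, car, display, chair, lamp, loudspeaker, rifle, sofa, table,
    # telephone, vessel)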
    classes = [
        "02691156",
        "02828884",
        "02933112",
        "02958343",
        "03211117",
        "03001627",
        "03636649",
        "03691459",
        "04090263",
        "04256520",
        "04379243",
        "04401088",
        "04530566",
    ]

    t_start = time.time()
    for c in classes:
        main(c)

    t_end = time.time()
    print("Total processing time: ", t_end - t_start)
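
# A minimal loader sketch for one processed sample (paths assume the defaults
# above; "<obj>" is a placeholder for an object id):
#
#   pcl = np.load("data/shapenet_psr/02691156/<obj>/pointcloud.npz")
#   points, normals = pcl["points"], pcl["normals"]
#   psr = np.load("data/shapenet_psr/02691156/<obj>/psr.npz")["psr"].astype(np.float32)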