# PVD/dataset/partnet.py

import json
import os
import random
import numpy as np
import torch
import trimesh
from plyfile import PlyData, PlyElement
from torch.utils.data import Dataset


def project_pc_to_image(points, resolution=64):
    """Project a point cloud onto the three axis-aligned planes.

    :param points: (n, 3) array with coordinates in (-1, 1)
    :return: (resolution, 3 * resolution) binary image, one view per axis
    """
    img = []
    for i in range(3):
        canvas = np.zeros((resolution, resolution))
        axis = [0, 1, 2]
        axis.remove(i)
        proj_points = (points[:, axis] + 1) / 2 * resolution
        # np.int was removed in NumPy 1.24; use the builtin int instead
        proj_points = proj_points.astype(int)
        # guard against points lying exactly on the upper boundary
        proj_points = np.clip(proj_points, 0, resolution - 1)
        canvas[proj_points[:, 0], proj_points[:, 1]] = 1
        img.append(canvas)
    img = np.concatenate(img, axis=1)
    return img
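

# Illustrative usage sketch (assumption: the cloud is already normalized to (-1, 1)):
#
#     pts = np.random.uniform(-1, 1, size=(1024, 3))
#     views = project_pc_to_image(pts, resolution=64)  # shape (64, 192): three 64x64 views side by side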


def write_ply(points, filename, text=False):
    """Write an (N, 3) array of points to `filename` in PLY format."""
    points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
    vertex = np.array(points, dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    el = PlyElement.describe(vertex, "vertex", comments=["vertices"])
    with open(filename, mode="wb") as f:
        PlyData([el], text=text).write(f)


def rotate_point_cloud(points, transformation_mat):
    """Apply a 3x3 transformation matrix to an (n, 3) point cloud."""
    new_points = np.dot(transformation_mat, points.T).T
    return new_points


def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
    """Align 3D-EPN shapes to ShapeNet coordinates.

    Note: `axis` and `angle_deg` are currently unused; the matrix below is a
    hard-coded 90-degree rotation about the y-axis (cos = 2.22e-16 ~ 0,
    sin = 1), i.e. the value the commented-out pymesh code would produce.
    """
    # angle = math.radians(angle_deg)
    # rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
    # rot_m = rot_m.to_matrix()
    rot_m = np.array(
        [
            [2.22044605e-16, 0.00000000e00, 1.00000000e00],
            [0.00000000e00, 1.00000000e00, 0.00000000e00],
            [-1.00000000e00, 0.00000000e00, 2.22044605e-16],
        ]
    )
    new_points = rotate_point_cloud(points, rot_m)
    return new_points


def downsample_point_cloud(points, n_pts):
    """Downsample a point cloud by random choice (with replacement).

    :param points: (n, 3)
    :param n_pts: int, number of points to keep
    :return: (n_pts, 3)
    """
    p_idx = random.choices(list(range(points.shape[0])), k=n_pts)
    return points[p_idx]


def upsample_point_cloud(points, n_pts):
    """Upsample a point cloud by duplicating randomly chosen points.

    :param points: (n, 3)
    :param n_pts: int, must be greater than n
    :return: (n_pts, 3)
    """
    p_idx = random.choices(list(range(points.shape[0])), k=n_pts - points.shape[0])
    dup_points = points[p_idx]
    points = np.concatenate([points, dup_points], axis=0)
    return points


def sample_point_cloud_by_n(points, n_pts):
    """Resample a point cloud to exactly `n_pts` points."""
    if n_pts > points.shape[0]:
        return upsample_point_cloud(points, n_pts)
    elif n_pts < points.shape[0]:
        return downsample_point_cloud(points, n_pts)
    else:
        return points
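

# Minimal usage sketch (illustrative only): resampling a cloud of 5 points
# up to 8 and down to 3.
#
#     pts = np.random.rand(5, 3)
#     assert sample_point_cloud_by_n(pts, 8).shape == (8, 3)
#     assert sample_point_cloud_by_n(pts, 3).shape == (3, 3)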


def collect_data_id(split_dir, classname, phase):
    """Collect annotation ids (`anno_id`) from a `<classname>.<phase>.json` split file."""
    filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
    if not os.path.exists(filename):
        raise ValueError("Invalid filepath: {}".format(filename))
    all_ids = []
    with open(filename, "r") as fp:
        info = json.load(fp)
    for item in info:
        all_ids.append(item["anno_id"])
    return all_ids


class GANdatasetPartNet(Dataset):
    def __init__(self, phase, data_root, category, n_pts):
        super(GANdatasetPartNet, self).__init__()
        if phase == "validation":
            phase = "val"
        self.phase = phase
        self.aug = phase == "train"
        self.data_root = data_root

        # keep only the shapes whose per-point part labels exist on disk
        shape_names = collect_data_id(
            os.path.join(self.data_root, "partnet_labels/partnet_train_val_test_split"), category, phase
        )
        self.shape_names = []
        for name in shape_names:
            path = os.path.join(self.data_root, "partnet_labels/partnet_pc_label", name)
            if os.path.exists(path):
                self.shape_names.append(name)

        self.n_pts = n_pts
        self.raw_n_pts = self.n_pts // 2

        # seeded RNG so part removal is deterministic outside of training
        self.rng = random.Random(1234)

    @staticmethod
    def load_point_cloud(path):
        pc = trimesh.load(path)
        pc = pc.vertices / 2.0  # scale to unit sphere
        return pc

    @staticmethod
    def read_point_cloud_part_label(path):
        """Read one integer part label per line from a label file."""
        with open(path, "r") as fp:
            labels = fp.readlines()
        labels = np.array([int(x) for x in labels])
        return labels

    def random_rm_parts(self, raw_pc, part_labels):
        """Randomly drop whole parts, keeping between 1 and len(part_ids) - 1 of them."""
        part_ids = sorted(np.unique(part_labels).tolist())
        if self.phase == "train":
            random.shuffle(part_ids)
            n_part_keep = random.randint(1, max(1, len(part_ids) - 1))
        else:
            # use the seeded RNG so evaluation is reproducible
            self.rng.shuffle(part_ids)
            n_part_keep = self.rng.randint(1, max(1, len(part_ids) - 1))
        part_ids_keep = part_ids[:n_part_keep]
        point_idx = []
        for i in part_ids_keep:
            point_idx.extend(np.where(part_labels == i)[0].tolist())
        raw_pc = raw_pc[point_idx]
        return raw_pc, n_part_keep

    def __getitem__(self, index):
        # partial ("raw") input: the shape with some parts removed
        raw_shape_name = self.shape_names[index]
        raw_ply_path = os.path.join(self.data_root, "partnet_data", raw_shape_name, "point_sample/ply-10000.ply")
        raw_pc = self.load_point_cloud(raw_ply_path)
        raw_label_path = os.path.join(
            self.data_root, "partnet_labels/partnet_pc_label", raw_shape_name, "label-merge-level1-10000.txt"
        )
        part_labels = self.read_point_cloud_part_label(raw_label_path)
        raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
        raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
        raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)

        # complete ("real") target: the full shape resampled to n_pts points
        real_shape_name = self.shape_names[index]
        real_ply_path = os.path.join(self.data_root, "partnet_data", real_shape_name, "point_sample/ply-10000.ply")
        real_pc = self.load_point_cloud(real_ply_path)
        real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
        real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)

        return {
            "raw": raw_pc,
            "real": real_pc,
            "raw_id": raw_shape_name,
            "real_id": real_shape_name,
            "n_part_keep": n_part_keep,
            "idx": index,
        }

    def __len__(self):
        return len(self.shape_names)
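

if __name__ == "__main__":
    # Hedged usage sketch: `data_root` and `category` below are placeholders;
    # the actual values depend on how the PartNet data was prepared locally.
    dataset = GANdatasetPartNet(phase="train", data_root="data/partnet", category="Chair", n_pts=2048)
    print("shapes:", len(dataset))
    if len(dataset) > 0:
        sample = dataset[0]
        # "raw" is (3, n_pts // 2), the partial input; "real" is (3, n_pts), the full shape
        print(sample["raw"].shape, sample["real"].shape, sample["n_part_keep"])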