import json
import os
import random

import numpy as np
import torch
import trimesh
from plyfile import PlyData, PlyElement
from torch.utils.data import Dataset


def project_pc_to_image(points, resolution=64):
    """Project a point cloud onto the three axis-aligned planes as binary images.

    :param points: (n, 3) array, coordinates expected in range [-1, 1]
    :param resolution: side length in pixels of each projected image
    :return: (resolution, 3 * resolution) array holding the three projections
        (dropping x, y, z respectively) side by side; occupied pixels are 1
    """
    img = []
    for i in range(3):
        canvas = np.zeros((resolution, resolution))
        axis = [0, 1, 2]
        axis.remove(i)
        proj_points = (points[:, axis] + 1) / 2 * resolution
        # `np.int` (an alias of the builtin `int`) was removed in NumPy 1.24.
        # Clip so a coordinate of exactly +1 (or slightly outside [-1, 1]) lands on
        # the border pixel instead of raising IndexError or wrapping around.
        proj_points = np.clip(proj_points.astype(int), 0, resolution - 1)
        canvas[proj_points[:, 0], proj_points[:, 1]] = 1
        img.append(canvas)
    return np.concatenate(img, axis=1)


def write_ply(points, filename, text=False):
    """Write an (N, 3) point array to `filename` in PLY format.

    :param points: (N, 3) array of xyz coordinates (stored as float32)
    :param filename: output path
    :param text: write ASCII PLY if True, binary PLY otherwise
    """
    points = np.asarray(points)
    # Fill the structured array column-wise instead of via a Python list of tuples.
    vertex = np.empty(points.shape[0], dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    vertex["x"] = points[:, 0]
    vertex["y"] = points[:, 1]
    vertex["z"] = points[:, 2]
    el = PlyElement.describe(vertex, "vertex", comments=["vertices"])
    with open(filename, mode="wb") as f:
        PlyData([el], text=text).write(f)


def rotate_point_cloud(points, transformation_mat):
    """Apply a (3, 3) transformation matrix to every row of an (n, 3) point array."""
    return np.dot(transformation_mat, points.T).T


def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
    """Align 3D-EPN shapes to ShapeNet coordinates.

    NOTE(review): `axis` and `angle_deg` are currently ignored; the matrix below
    is a fixed +90 degree rotation about the y axis (up to float round-off).
    Confirm callers rely only on that before making it honour the arguments.
    """
    rot_m = np.array(
        [
            [2.22044605e-16, 0.00000000e00, 1.00000000e00],
            [0.00000000e00, 1.00000000e00, 0.00000000e00],
            [-1.00000000e00, 0.00000000e00, 2.22044605e-16],
        ]
    )
    return rotate_point_cloud(points, rot_m)


def downsample_point_cloud(points, n_pts):
    """Randomly downsample points (uses the module-level `random` state).

    Points are drawn without replacement, so no point is duplicated. If `n_pts`
    exceeds the number of points, falls back to drawing with replacement.

    :param points: (n, 3)
    :param n_pts: int, number of points to keep
    :return: (n_pts, 3)
    """
    n = points.shape[0]
    if n_pts <= n:
        p_idx = random.sample(range(n), n_pts)
    else:
        p_idx = random.choices(range(n), k=n_pts)
    return points[p_idx]


def upsample_point_cloud(points, n_pts):
    """Upsample points by appending randomly chosen duplicates.

    :param points: (n, 3), n > 0
    :param n_pts: int, > n
    :return: (n_pts, 3); the original n points come first
    :raises ValueError: if `points` is empty
    """
    n = points.shape[0]
    if n == 0:
        raise ValueError("cannot upsample an empty point cloud")
    p_idx = random.choices(range(n), k=n_pts - n)
    return np.concatenate([points, points[p_idx]], axis=0)


def sample_point_cloud_by_n(points, n_pts):
    """Resample a point cloud to exactly `n_pts` points."""
    if n_pts > points.shape[0]:
        return upsample_point_cloud(points, n_pts)
    if n_pts < points.shape[0]:
        return downsample_point_cloud(points, n_pts)
    return points


def collect_data_id(split_dir, classname, phase):
    """Return the `anno_id` of every entry in `<split_dir>/<classname>.<phase>.json`.

    :raises ValueError: if the split file does not exist
    """
    filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
    if not os.path.exists(filename):
        raise ValueError("Invalid filepath: {}".format(filename))
    with open(filename, "r") as fp:
        info = json.load(fp)
    return [item["anno_id"] for item in info]


class GANdatasetPartNet(Dataset):
    """PartNet dataset yielding (partial, complete) point-cloud pairs.

    The partial ("raw") cloud is the shape with a random subset of its level-1
    parts removed, resampled to `n_pts // 2` points; the complete ("real") cloud
    is the same shape resampled to `n_pts` points. Both tensors are (3, N).
    """

    def __init__(self, phase, data_root, category, n_pts):
        super(GANdatasetPartNet, self).__init__()
        if phase == "validation":
            phase = "val"
        self.phase = phase
        self.aug = phase == "train"
        self.data_root = data_root
        shape_names = collect_data_id(
            os.path.join(self.data_root, "partnet_labels/partnet_train_val_test_split"), category, phase
        )
        # Keep only shapes that have a point-label directory.
        self.shape_names = [
            name
            for name in shape_names
            if os.path.exists(os.path.join(self.data_root, "partnet_labels/partnet_pc_label", name))
        ]
        self.n_pts = n_pts
        self.raw_n_pts = self.n_pts // 2
        # Seeded RNG for part removal outside training. It is shared across calls,
        # so results still depend on access order (and per DataLoader worker).
        self.rng = random.Random(1234)

    @staticmethod
    def load_point_cloud(path):
        """Load PLY vertices and halve the coordinates (comment in source: scale to unit sphere)."""
        pc = trimesh.load(path)
        return pc.vertices / 2.0

    @staticmethod
    def read_point_cloud_part_label(path):
        """Read one integer part label per line, skipping blank lines."""
        with open(path, "r") as fp:
            return np.array([int(line) for line in fp if line.strip()])

    def random_rm_parts(self, raw_pc, part_labels):
        """Keep a random non-empty subset of parts (all of them if only one exists).

        :return: (points of the kept parts, grouped part by part, number of parts kept)
        """
        part_ids = sorted(np.unique(part_labels).tolist())
        # Training uses the global RNG; other phases use the seeded instance RNG.
        rng = random if self.phase == "train" else self.rng
        rng.shuffle(part_ids)
        n_part_keep = rng.randint(1, max(1, len(part_ids) - 1))
        point_idx = []
        for part_id in part_ids[:n_part_keep]:
            point_idx.extend(np.flatnonzero(part_labels == part_id).tolist())
        return raw_pc[point_idx], n_part_keep

    def __getitem__(self, index):
        shape_name = self.shape_names[index]
        ply_path = os.path.join(self.data_root, "partnet_data", shape_name, "point_sample/ply-10000.ply")
        # The raw and real clouds come from the same file: load it only once.
        full_pc = self.load_point_cloud(ply_path)

        label_path = os.path.join(
            self.data_root, "partnet_labels/partnet_pc_label", shape_name, "label-merge-level1-10000.txt"
        )
        part_labels = self.read_point_cloud_part_label(label_path)
        # Fancy indexing in random_rm_parts copies, so full_pc is left untouched.
        raw_pc, n_part_keep = self.random_rm_parts(full_pc, part_labels)
        raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
        raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)

        real_pc = sample_point_cloud_by_n(full_pc, self.n_pts)
        real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)

        return {
            "raw": raw_pc,
            "real": real_pc,
            "raw_id": shape_name,
            "real_id": shape_name,
            "n_part_keep": n_part_keep,
            "idx": index,
        }

    def __len__(self):
        return len(self.shape_names)