# PVD/datasets/partnet.py
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import os
import json
import random
import trimesh
import csv
from plyfile import PlyData, PlyElement
from glob import glob


def project_pc_to_image(points, resolution=64):
    """Project a point cloud into three 2D binary images, one per axis.
    :param points: (n, 3), coordinates in range (-1, 1)
    :return: binary image of shape (resolution, 3 * resolution)
    """
    img = []
    for i in range(3):
        canvas = np.zeros((resolution, resolution))
        axis = [0, 1, 2]
        axis.remove(i)  # project onto the plane orthogonal to axis i
        proj_points = (points[:, axis] + 1) / 2 * resolution
        # np.int was removed in NumPy 1.24; the builtin int is equivalent here
        proj_points = proj_points.astype(int)
        # clip so points exactly at +1 do not index one past the canvas edge
        proj_points = np.clip(proj_points, 0, resolution - 1)
        canvas[proj_points[:, 0], proj_points[:, 1]] = 1
        img.append(canvas)
    img = np.concatenate(img, axis=1)
    return img
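
# A minimal usage sketch (not part of the original module): project a random
# cloud already normalized to (-1, 1) and check the stacked image shape.
#   pts = np.random.uniform(-1, 1, size=(1024, 3))
#   img = project_pc_to_image(pts, resolution=64)
#   assert img.shape == (64, 192)  # three 64x64 views concatenated along x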


def write_ply(points, filename, text=False):
    """ input: Nx3, write points to filename as PLY format. """
    points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
    vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
    el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    with open(filename, mode='wb') as f:
        PlyData([el], text=text).write(f)
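
# Usage sketch (hypothetical output path): dump a cloud for inspection in a
# viewer such as MeshLab; text=True writes ASCII PLY instead of binary.
#   cloud = np.random.rand(100, 3).astype(np.float32)
#   write_ply(cloud, '/tmp/debug_cloud.ply', text=True)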


def rotate_point_cloud(points, transformation_mat):
    new_points = np.dot(transformation_mat, points.T).T
    return new_points


def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
    """Align 3D-EPN shapes to ShapeNet coordinates."""
    # `axis` and `angle_deg` are currently unused: the matrix below is the
    # precomputed result of a 90-degree rotation about the y-axis, originally
    # obtained via the commented-out pymesh quaternion route.
    # angle = math.radians(angle_deg)
    # rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
    # rot_m = rot_m.to_matrix()
    rot_m = np.array([[ 2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
                      [ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
                      [-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]])
    new_points = rotate_point_cloud(points, rot_m)
    return new_points
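
# Sketch of the explicit construction the constant above encodes (assumes
# `import math`; not in the original file): a 90-degree rotation about y.
#   theta = math.radians(90)
#   rot_y = np.array([[ math.cos(theta), 0., math.sin(theta)],
#                     [ 0.,              1., 0.             ],
#                     [-math.sin(theta), 0., math.cos(theta)]])
# rot_y matches rot_m up to floating-point error in the cosine terms.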


def downsample_point_cloud(points, n_pts):
    """Downsample points by random choice. Note that sampling is with
    replacement, so the result may contain duplicate points.
    :param points: (n, 3)
    :param n_pts: int, <= n
    :return: (n_pts, 3)
    """
    p_idx = random.choices(list(range(points.shape[0])), k=n_pts)
    return points[p_idx]


def upsample_point_cloud(points, n_pts):
    """Upsample points by randomly duplicating existing points.
    :param points: (n, 3)
    :param n_pts: int, > n
    :return: (n_pts, 3)
    """
    p_idx = random.choices(list(range(points.shape[0])), k=n_pts - points.shape[0])
    dup_points = points[p_idx]
    points = np.concatenate([points, dup_points], axis=0)
    return points


def sample_point_cloud_by_n(points, n_pts):
    """resample point cloud to given number of points"""
    if n_pts > points.shape[0]:
        return upsample_point_cloud(points, n_pts)
    elif n_pts < points.shape[0]:
        return downsample_point_cloud(points, n_pts)
    else:
        return points
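
# Usage sketch: sample_point_cloud_by_n always returns exactly n_pts points,
# whichever branch it takes.
#   pts = np.random.rand(1500, 3)
#   assert sample_point_cloud_by_n(pts, 2048).shape == (2048, 3)  # upsampled
#   assert sample_point_cloud_by_n(pts, 1024).shape == (1024, 3)  # downsampled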


def collect_data_id(split_dir, classname, phase):
    filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
    if not os.path.exists(filename):
        raise ValueError("Invalid filepath: {}".format(filename))

    all_ids = []
    with open(filename, 'r') as fp:
        info = json.load(fp)
    for item in info:
        all_ids.append(item["anno_id"])
    return all_ids
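
# The split files are assumed to be JSON lists of records carrying an
# "anno_id" field, e.g. a hypothetical Chair.train.json:
#   [{"anno_id": "12345"}, {"anno_id": "67890"}]
# collect_data_id(split_dir, "Chair", "train") would then return
# ["12345", "67890"].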


class GANdatasetPartNet(Dataset):
    def __init__(self, phase, data_root, category, n_pts):
        super(GANdatasetPartNet, self).__init__()
        if phase == "validation":
            phase = "val"
        self.phase = phase
        self.aug = phase == "train"
        self.data_root = data_root

        shape_names = collect_data_id(
            os.path.join(self.data_root, 'partnet_labels/partnet_train_val_test_split'),
            category, phase)
        # keep only shapes whose per-point part labels exist on disk
        self.shape_names = []
        for name in shape_names:
            path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', name)
            if os.path.exists(path):
                self.shape_names.append(name)

        self.n_pts = n_pts
        self.raw_n_pts = self.n_pts // 2
        # fixed seed so val/test part removal is deterministic across runs
        self.rng = random.Random(1234)

    @staticmethod
    def load_point_cloud(path):
        pc = trimesh.load(path)
        pc = pc.vertices / 2.0  # scale to unit sphere
        return pc

    @staticmethod
    def read_point_cloud_part_label(path):
        with open(path, 'r') as fp:
            labels = fp.readlines()
        labels = np.array([int(x) for x in labels])
        return labels
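
    # The label file is assumed to hold one integer part id per line, aligned
    # with the 10000 points in ply-10000.ply, e.g. a file containing
    #   1
    #   1
    #   3
    # yields np.array([1, 1, 3]) from read_point_cloud_part_label.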

    def random_rm_parts(self, raw_pc, part_labels):
        """Drop a random subset of parts; a shape with more than one part
        always has at least one part removed."""
        part_ids = sorted(np.unique(part_labels).tolist())
        if self.phase == "train":
            random.shuffle(part_ids)
            n_part_keep = random.randint(1, max(1, len(part_ids) - 1))
        else:
            # use the seeded RNG so val/test partial shapes are reproducible
            self.rng.shuffle(part_ids)
            n_part_keep = self.rng.randint(1, max(1, len(part_ids) - 1))
        part_ids_keep = part_ids[:n_part_keep]

        point_idx = []
        for i in part_ids_keep:
            point_idx.extend(np.where(part_labels == i)[0].tolist())
        raw_pc = raw_pc[point_idx]
        return raw_pc, n_part_keep
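
    # Illustration (hypothetical labels): with part_labels = [0, 0, 1, 2, 2]
    # there are 3 parts, so n_part_keep = randint(1, 2) keeps one or two of
    # them, and the returned cloud holds only the points of the kept parts.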

    def __getitem__(self, index):
        # "raw": partial shape with a random subset of parts removed
        raw_shape_name = self.shape_names[index]
        raw_ply_path = os.path.join(self.data_root, 'partnet_data', raw_shape_name,
                                    'point_sample/ply-10000.ply')
        raw_pc = self.load_point_cloud(raw_ply_path)
        raw_label_path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label',
                                      raw_shape_name, 'label-merge-level1-10000.txt')
        part_labels = self.read_point_cloud_part_label(raw_label_path)
        raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
        raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
        raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)

        # "real": the same complete shape, resampled to n_pts points
        real_shape_name = self.shape_names[index]
        real_ply_path = os.path.join(self.data_root, 'partnet_data', real_shape_name,
                                     'point_sample/ply-10000.ply')
        real_pc = self.load_point_cloud(real_ply_path)
        real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
        real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)

        return {"raw": raw_pc, "real": real_pc, "raw_id": raw_shape_name,
                "real_id": real_shape_name, 'n_part_keep': n_part_keep, 'idx': index}
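
    # Shape note: "raw" is (3, n_pts // 2) from the partial shape and "real"
    # is (3, n_pts) from the complete shape, both channels-first float32
    # tensors ready for Conv1d-style consumers.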

    def __len__(self):
        return len(self.shape_names)


if __name__ == '__main__':
    data_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetPointCloud'
    data_raw_root = '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc'
    pc_dataroot = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k'
    sn_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2'
    classes = 'car'
    npoints = 2048

    # from datasets.shapenet_data_pc import ShapeNet15kPointClouds
    # pc_ds = ShapeNet15kPointClouds(root_dir=pc_dataroot,
    #                                categories=[classes], split='train',
    #                                tr_sample_size=npoints,
    #                                te_sample_size=npoints,
    #                                scale=1.,
    #                                normalize_per_shape=False,
    #                                normalize_std_per_axis=False,
    #                                random_subsample=True)

    # GANdatasetPartNet takes (phase, data_root, category, n_pts); the extra
    # shift/scale arguments from an older signature have been dropped to match.
    train_ds = GANdatasetPartNet('test', pc_dataroot, classes, npoints)
    d1 = train_ds[0]
    real = d1['real']
    raw = d1['raw']

    # __getitem__ returns no normalization stats ('m'/'s'), so concatenate
    # directly: (3, n//2) partial + (3, n) complete -> (3n//2, 3) after transpose
    x = torch.cat([raw, real], dim=-1).transpose(0, 1)
    write_ply(x.numpy(), 'x.ply')