# PVD/datasets/shapenet_data_pc.py

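"""Dataset loaders for uniformly sampled 15k-point ShapeNet point clouds."""
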
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
synsetid_to_cate = {
"02691156": "airplane",
"02773838": "bag",
"02801938": "basket",
"02808440": "bathtub",
"02818832": "bed",
"02828884": "bench",
"02876657": "bottle",
"02880940": "bowl",
"02924116": "bus",
"02933112": "cabinet",
"02747177": "can",
"02942699": "camera",
"02954340": "cap",
"02958343": "car",
"03001627": "chair",
"03046257": "clock",
"03207941": "dishwasher",
"03211117": "monitor",
"04379243": "table",
"04401088": "telephone",
"02946921": "tin_can",
"04460130": "tower",
"04468005": "train",
"03085013": "keyboard",
"03261776": "earphone",
"03325088": "faucet",
"03337140": "file",
"03467517": "guitar",
"03513137": "helmet",
"03593526": "jar",
"03624134": "knife",
"03636649": "lamp",
"03642806": "laptop",
"03691459": "speaker",
"03710193": "mailbox",
"03759954": "microphone",
"03761084": "microwave",
"03790512": "motorcycle",
"03797390": "mug",
"03928116": "piano",
"03938244": "pillow",
"03948459": "pistol",
"03991062": "pot",
"04004475": "printer",
"04074963": "remote_control",
"04090263": "rifle",
"04099429": "rocket",
"04225987": "skateboard",
"04256520": "sofa",
"04330267": "stove",
"04530566": "vessel",
"04554684": "washer",
"02992529": "cellphone",
"02843684": "birdhouse",
"02871439": "bookshelf",
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}


class Uniform15KPC(Dataset):
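    """Uniformly sampled 15k-point clouds loaded from per-synset .npy dumps.

    Expects <root_dir>/<synset_id>/<split>/<model_id>.npy files of shape
    (15000, input_dim). The first 10k points of each cloud form the train
    pool and the remaining 5k the test pool; points are normalized either
    across the whole dataset, per shape, or into a per-shape bounding box.
    """
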
def __init__(
self,
root_dir,
subdirs,
tr_sample_size=10000,
te_sample_size=10000,
split="train",
scale=1.0,
normalize_per_shape=False,
box_per_shape=False,
random_subsample=False,
normalize_std_per_axis=False,
all_points_mean=None,
all_points_std=None,
input_dim=3,
use_mask=False,
):
self.root_dir = root_dir
self.split = split
self.in_tr_sample_size = tr_sample_size
self.in_te_sample_size = te_sample_size
self.subdirs = subdirs
self.scale = scale
self.random_subsample = random_subsample
self.input_dim = input_dim
self.use_mask = use_mask
self.box_per_shape = box_per_shape
if use_mask:
self.mask_transform = PointCloudMasks(radius=5, elev=5, azim=90)
self.all_cate_mids = []
self.cate_idx_lst = []
self.all_points = []
for cate_idx, subd in enumerate(self.subdirs):
# NOTE: [subd] here is synset id
sub_path = os.path.join(root_dir, subd, self.split)
if not os.path.isdir(sub_path):
print("Directory missing : %s" % sub_path)
continue
all_mids = []
for x in os.listdir(sub_path):
                if not x.endswith(".npy"):
                    continue
                all_mids.append(os.path.join(self.split, x[: -len(".npy")]))
# NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
for mid in all_mids:
                obj_fname = os.path.join(root_dir, subd, mid + ".npy")
                try:
                    point_cloud = np.load(obj_fname)  # (15000, 3)
                except (OSError, ValueError):
                    # skip files that are missing or fail to load
                    continue
                assert point_cloud.shape[0] == 15000, obj_fname
self.all_points.append(point_cloud[np.newaxis, ...])
self.cate_idx_lst.append(cate_idx)
self.all_cate_mids.append((subd, mid))
# Shuffle the index deterministically (based on the number of examples)
self.shuffle_idx = list(range(len(self.all_points)))
random.Random(38383).shuffle(self.shuffle_idx)
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
# Normalization
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
self.normalize_per_shape = normalize_per_shape
self.normalize_std_per_axis = normalize_std_per_axis
if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats
self.all_points_mean = all_points_mean
self.all_points_std = all_points_std
elif self.normalize_per_shape: # per shape normalization
B, N = self.all_points.shape[:2]
self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
elif self.box_per_shape:
B, N = self.all_points.shape[:2]
self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
            self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(
                axis=1
            ).reshape(B, 1, input_dim)
else: # normalize across the dataset
self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
if self.box_per_shape:
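            # box normalization maps each shape into [0, 1]^input_dim; recenter to [-0.5, 0.5]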
self.all_points = self.all_points - 0.5
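        # each 15k-point cloud is split: the first 10k points are the train pool,
        # the last 5k the held-out test pool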
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
self.tr_sample_size = min(10000, tr_sample_size)
self.te_sample_size = min(5000, te_sample_size)
print("Total number of data:%d" % len(self.train_points))
2023-04-11 09:12:58 +00:00
print("Min number of points: (train)%d (test)%d" % (self.tr_sample_size, self.te_sample_size))
2021-10-19 20:54:46 +00:00
        assert self.scale == 1, "Scale (!= 1) is deprecated"

    def get_pc_stats(self, idx):
if self.normalize_per_shape or self.box_per_shape:
m = self.all_points_mean[idx].reshape(1, self.input_dim)
s = self.all_points_std[idx].reshape(1, -1)
return m, s
        return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)

    def renormalize(self, mean, std):
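        # undo the current normalization, then re-apply with the provided stats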
self.all_points = self.all_points * self.all_points_std + self.all_points_mean
self.all_points_mean = mean
self.all_points_std = std
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

    def __len__(self):
        return len(self.train_points)

    def __getitem__(self, idx):
tr_out = self.train_points[idx]
if self.random_subsample:
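            # NOTE: np.random.choice samples with replacement by default,
            # so duplicate point indices are possible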
tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
else:
tr_idxs = np.arange(self.tr_sample_size)
tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
te_out = self.test_points[idx]
if self.random_subsample:
te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
else:
te_idxs = np.arange(self.te_sample_size)
te_out = torch.from_numpy(te_out[te_idxs, :]).float()
m, s = self.get_pc_stats(idx)
cate_idx = self.cate_idx_lst[idx]
sid, mid = self.all_cate_mids[idx]
out = {
"idx": idx,
"train_points": tr_out,
"test_points": te_out,
"mean": m,
"std": s,
"cate_idx": cate_idx,
"sid": sid,
"mid": mid,
}
if self.use_mask:
tr_mask = self.mask_transform(tr_out)
out["train_masks"] = tr_mask
        return out


class ShapeNet15kPointClouds(Uniform15KPC):
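    """ShapeNetCore.v2 15k-point-cloud dataset, selected by category name."""
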
def __init__(
self,
root_dir="data/ShapeNetCore.v2.PC15k",
categories=["airplane"],
tr_sample_size=10000,
te_sample_size=2048,
split="train",
scale=1.0,
normalize_per_shape=False,
normalize_std_per_axis=False,
box_per_shape=False,
random_subsample=False,
all_points_mean=None,
all_points_std=None,
use_mask=False,
):
self.root_dir = root_dir
self.split = split
assert self.split in ["train", "test", "val"]
self.tr_sample_size = tr_sample_size
self.te_sample_size = te_sample_size
self.cates = categories
if "all" in categories:
self.synset_ids = list(cate_to_synsetid.values())
else:
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
# assert 'v2' in root_dir, "Only supporting v2 right now."
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
        super().__init__(
            root_dir,
            self.synset_ids,
            tr_sample_size=tr_sample_size,
            te_sample_size=te_sample_size,
            split=split,
            scale=scale,
            normalize_per_shape=normalize_per_shape,
            box_per_shape=box_per_shape,
            normalize_std_per_axis=normalize_std_per_axis,
            random_subsample=random_subsample,
            all_points_mean=all_points_mean,
            all_points_std=all_points_std,
            input_dim=3,
            use_mask=use_mask,
        )
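
# Example usage (a minimal sketch; assumes a ShapeNetCore.v2.PC15k dump laid
# out as <root_dir>/<synset_id>/<split>/<model_id>.npy):
#
#     dataset = ShapeNet15kPointClouds(
#         root_dir="data/ShapeNetCore.v2.PC15k",
#         categories=["airplane"],
#         split="train",
#         random_subsample=True,
#     )
#     loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
#     batch = next(iter(loader))
#     print(batch["train_points"].shape)  # torch.Size([16, 10000, 3])
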
####################################################################################