2021-10-19 20:54:46 +00:00
|
|
|
import os
|
|
|
|
import random
|
2023-04-11 09:12:58 +00:00
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
import numpy as np
|
2023-04-11 09:12:58 +00:00
|
|
|
import torch
|
|
|
|
from torch.utils.data import Dataset
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
# ShapeNet synset id -> human-readable category name.
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
# Entry order is significant: ShapeNet15kPointClouds(categories=["all"]) walks
# cate_to_synsetid.values() in this order, and that order fixes each category's cate_idx.
synsetid_to_cate = dict(
    [
        ("02691156", "airplane"),
        ("02773838", "bag"),
        ("02801938", "basket"),
        ("02808440", "bathtub"),
        ("02818832", "bed"),
        ("02828884", "bench"),
        ("02876657", "bottle"),
        ("02880940", "bowl"),
        ("02924116", "bus"),
        ("02933112", "cabinet"),
        ("02747177", "can"),
        ("02942699", "camera"),
        ("02954340", "cap"),
        ("02958343", "car"),
        ("03001627", "chair"),
        ("03046257", "clock"),
        ("03207941", "dishwasher"),
        ("03211117", "monitor"),
        ("04379243", "table"),
        ("04401088", "telephone"),
        ("02946921", "tin_can"),
        ("04460130", "tower"),
        ("04468005", "train"),
        ("03085013", "keyboard"),
        ("03261776", "earphone"),
        ("03325088", "faucet"),
        ("03337140", "file"),
        ("03467517", "guitar"),
        ("03513137", "helmet"),
        ("03593526", "jar"),
        ("03624134", "knife"),
        ("03636649", "lamp"),
        ("03642806", "laptop"),
        ("03691459", "speaker"),
        ("03710193", "mailbox"),
        ("03759954", "microphone"),
        ("03761084", "microwave"),
        ("03790512", "motorcycle"),
        ("03797390", "mug"),
        ("03928116", "piano"),
        ("03938244", "pillow"),
        ("03948459", "pistol"),
        ("03991062", "pot"),
        ("04004475", "printer"),
        ("04074963", "remote_control"),
        ("04090263", "rifle"),
        ("04099429", "rocket"),
        ("04225987", "skateboard"),
        ("04256520", "sofa"),
        ("04330267", "stove"),
        ("04530566", "vessel"),
        ("04554684", "washer"),
        ("02992529", "cellphone"),
        ("02843684", "birdhouse"),
        ("02871439", "bookshelf"),
        # ('02858304', 'boat'): no boat in our dataset, merged into vessels
        # ('02834778', 'bicycle'): not in our taxonomy
    ]
)

# Inverse lookup: category name -> synset id, in the same entry order as above.
cate_to_synsetid = dict(zip(synsetid_to_cate.values(), synsetid_to_cate.keys()))
|
|
|
|
|
|
|
|
|
|
|
|
class Uniform15KPC(Dataset):
    """Dataset of fixed-size point clouds stored as ``<root_dir>/<subdir>/<split>/<name>.npy``.

    Every ``.npy`` file must contain exactly 15000 points. The first 10000 points of
    each (normalized) cloud form ``train_points`` and the remaining 5000 form
    ``test_points``. Normalization uses, in order of precedence: externally supplied
    ``all_points_mean``/``all_points_std``, per-shape statistics, a per-shape bounding
    box, or statistics over the whole dataset.
    """

    # Number of points each stored cloud must contain.
    _NUM_POINTS = 15000
    # Index at which each cloud is split into train points [:10000] and test points [10000:].
    _TRAIN_POINTS = 10000
    # With box_per_shape, points are first mapped to [0, 1] per axis (relative to the
    # shape's box), then shifted by this amount to land in [-0.5, 0.5].
    _BOX_OFFSET = 0.5

    def __init__(
        self,
        root_dir,
        subdirs,
        tr_sample_size=10000,
        te_sample_size=10000,
        split="train",
        scale=1.0,
        normalize_per_shape=False,
        box_per_shape=False,
        random_subsample=False,
        normalize_std_per_axis=False,
        all_points_mean=None,
        all_points_std=None,
        input_dim=3,
        use_mask=False,
    ):
        """Load every ``.npy`` cloud found under the requested subdirectories and normalize them.

        Args:
            root_dir: Dataset root directory.
            subdirs: Subdirectory names to load (synset ids in the ShapeNet subclass);
                the position of a subdir in this sequence becomes its ``cate_idx``.
            tr_sample_size: Train points returned per item (capped at 10000).
            te_sample_size: Test points returned per item (capped at 5000).
            split: Name of the split subdirectory inside each subdir.
            scale: Deprecated; must be 1.
            normalize_per_shape: Normalize each shape by its own mean/std.
            box_per_shape: Normalize each shape by its own bounding box into [-0.5, 0.5].
            random_subsample: Sample points at random (with replacement) in ``__getitem__``
                instead of taking the leading points.
            normalize_std_per_axis: Use a per-axis std instead of a single scalar std.
            all_points_mean, all_points_std: Precomputed statistics; when both are given
                they take precedence over every other normalization option.
            input_dim: Number of coordinates per point.
            use_mask: Attach ``train_masks`` from ``PointCloudMasks`` to each item.

        Raises:
            ValueError: If no readable point cloud is found.
            AssertionError: If a cloud does not have 15000 points or ``scale != 1``.
        """
        self.root_dir = root_dir
        self.split = split
        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size
        self.subdirs = subdirs
        self.scale = scale
        self.random_subsample = random_subsample
        self.input_dim = input_dim
        self.use_mask = use_mask
        self.box_per_shape = box_per_shape
        if use_mask:
            # NOTE(review): PointCloudMasks is not defined in this view; presumably
            # defined elsewhere in this module — confirm.
            self.mask_transform = PointCloudMasks(radius=5, elev=5, azim=90)

        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is synset id
            sub_path = os.path.join(root_dir, subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s" % sub_path)
                continue

            # NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>".
            # File order follows os.listdir, which is filesystem dependent.
            all_mids = [
                os.path.join(self.split, fname[: -len(".npy")])
                for fname in os.listdir(sub_path)
                if fname.endswith(".npy")
            ]

            for mid in all_mids:
                obj_fname = os.path.join(root_dir, subd, mid + ".npy")
                try:
                    point_cloud = np.load(obj_fname)  # (15k, 3)
                except Exception as err:
                    # Best-effort loading: skip unreadable files, but report them instead
                    # of silently swallowing every error (incl. KeyboardInterrupt).
                    print("Skipping unreadable file %s: %s" % (obj_fname, err))
                    continue

                assert point_cloud.shape[0] == self._NUM_POINTS
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))

        if not self.all_points:
            # np.concatenate would fail below with an opaque message; say what is missing.
            raise ValueError(
                "No point clouds found under %s for split %r and subdirs %s"
                % (root_dir, split, list(subdirs))
            )

        # Shuffle the index deterministically (fixed seed; the permutation depends only
        # on the number of examples).
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]

        # Normalization
        self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
        self.normalize_per_shape = normalize_per_shape
        self.normalize_std_per_axis = normalize_std_per_axis
        if all_points_mean is not None and all_points_std is not None:  # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        elif self.normalize_per_shape:  # per shape normalization
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
        elif self.box_per_shape:
            # "mean" is the box minimum and "std" the box extent, so (x - mean) / std is in [0, 1].
            B = self.all_points.shape[0]
            box_min = self.all_points.min(axis=1).reshape(B, 1, input_dim)
            box_max = self.all_points.max(axis=1).reshape(B, 1, input_dim)
            self.all_points_mean = box_min
            self.all_points_std = box_max - box_min
        else:  # normalize across the dataset
            self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        if self.box_per_shape:
            self.all_points = self.all_points - self._BOX_OFFSET
        self.train_points = self.all_points[:, : self._TRAIN_POINTS]
        self.test_points = self.all_points[:, self._TRAIN_POINTS :]

        self.tr_sample_size = min(self._TRAIN_POINTS, tr_sample_size)
        self.te_sample_size = min(self._NUM_POINTS - self._TRAIN_POINTS, te_sample_size)
        print("Total number of data:%d" % len(self.train_points))
        print("Min number of points: (train)%d (test)%d" % (self.tr_sample_size, self.te_sample_size))
        assert self.scale == 1, "Scale (!= 1) is deprecated"

    def get_pc_stats(self, idx):
        """Return the (mean, std) used to normalize item ``idx``, each reshaped to ``(1, k)``.

        Per-shape statistics are indexed by ``idx``; dataset-wide statistics are shared.
        """
        if self.normalize_per_shape or self.box_per_shape:
            m = self.all_points_mean[idx].reshape(1, self.input_dim)
            s = self.all_points_std[idx].reshape(1, -1)
            return m, s

        return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)

    def renormalize(self, mean, std):
        """Re-express all stored points using new normalization statistics.

        Undoes the current normalization (including the ``box_per_shape`` offset that
        ``__init__`` applies), stores ``mean``/``std``, and normalizes again with them,
        so the result matches what ``__init__`` produces when given these statistics.
        """
        normalized = self.all_points
        if self.box_per_shape:
            # __init__ subtracted the box offset after dividing; undo it before denormalizing.
            normalized = normalized + self._BOX_OFFSET
        raw_points = normalized * self.all_points_std + self.all_points_mean

        self.all_points_mean = mean
        self.all_points_std = std
        self.all_points = (raw_points - self.all_points_mean) / self.all_points_std
        if self.box_per_shape:
            self.all_points = self.all_points - self._BOX_OFFSET
        self.train_points = self.all_points[:, : self._TRAIN_POINTS]
        self.test_points = self.all_points[:, self._TRAIN_POINTS :]

    def __len__(self):
        """Number of shapes in the dataset."""
        return len(self.train_points)

    def _sample_points(self, points, sample_size):
        """Select ``sample_size`` rows of ``points`` and return them as a float32 tensor.

        With ``random_subsample`` the rows are drawn uniformly with replacement
        (``np.random.choice`` default); otherwise the leading rows are taken.
        """
        if self.random_subsample:
            idxs = np.random.choice(points.shape[0], sample_size)
        else:
            idxs = np.arange(sample_size)
        return torch.from_numpy(points[idxs, :]).float()

    def __getitem__(self, idx):
        """Return a dict with sampled train/test points, normalization stats and identifiers."""
        # Train points are sampled before test points, which fixes the order of RNG draws.
        tr_out = self._sample_points(self.train_points[idx], self.tr_sample_size)
        te_out = self._sample_points(self.test_points[idx], self.te_sample_size)

        m, s = self.get_pc_stats(idx)
        cate_idx = self.cate_idx_lst[idx]
        sid, mid = self.all_cate_mids[idx]

        out = {
            "idx": idx,
            "train_points": tr_out,
            "test_points": te_out,
            "mean": m,
            "std": s,
            "cate_idx": cate_idx,
            "sid": sid,
            "mid": mid,
        }

        if self.use_mask:
            tr_mask = self.mask_transform(tr_out)
            out["train_masks"] = tr_mask

        return out
|
|
|
|
|
|
|
|
|
|
|
|
class ShapeNet15kPointClouds(Uniform15KPC):
    """ShapeNet 15k-point clouds selected by category name.

    Maps category names (keys of ``cate_to_synsetid``) to synset-id subdirectories
    and delegates loading and normalization to ``Uniform15KPC`` with ``input_dim=3``.
    """

    def __init__(
        self,
        root_dir="data/ShapeNetCore.v2.PC15k",
        categories=None,
        tr_sample_size=10000,
        te_sample_size=2048,
        split="train",
        scale=1.0,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        box_per_shape=False,
        random_subsample=False,
        all_points_mean=None,
        all_points_std=None,
        use_mask=False,
    ):
        """Resolve ``categories`` to synset ids and load the requested split.

        Args:
            root_dir: Dataset root containing one subdirectory per synset id.
            categories: Category names, or a single name. If ``"all"`` is among them,
                every category in ``cate_to_synsetid`` is used. Defaults to ``["airplane"]``.
            split: One of ``"train"``, ``"test"``, ``"val"``.
            Remaining arguments are forwarded unchanged to ``Uniform15KPC``.

        Raises:
            AssertionError: If ``split`` is not a known split.
            KeyError: If a category name is not in ``cate_to_synsetid``.
        """
        # None sentinel instead of a shared mutable list default.
        if categories is None:
            categories = ["airplane"]
        elif isinstance(categories, str):
            # A bare string would otherwise be iterated character by character.
            categories = [categories]

        self.root_dir = root_dir
        self.split = split
        assert self.split in ["train", "test", "val"]
        # Overwritten (capped) by Uniform15KPC.__init__ below.
        self.tr_sample_size = tr_sample_size
        self.te_sample_size = te_sample_size
        self.cates = categories
        if "all" in categories:
            self.synset_ids = list(cate_to_synsetid.values())
        else:
            unknown = [c for c in self.cates if c not in cate_to_synsetid]
            if unknown:
                raise KeyError("Unknown categories %s; expected names from cate_to_synsetid or 'all'" % unknown)
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]

        # assert 'v2' in root_dir, "Only supporting v2 right now."
        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]

        super(ShapeNet15kPointClouds, self).__init__(
            root_dir,
            self.synset_ids,
            tr_sample_size=tr_sample_size,
            te_sample_size=te_sample_size,
            split=split,
            scale=scale,
            normalize_per_shape=normalize_per_shape,
            box_per_shape=box_per_shape,
            normalize_std_per_axis=normalize_std_per_axis,
            random_subsample=random_subsample,
            all_points_mean=all_points_mean,
            all_points_std=all_points_std,
            input_dim=3,
            use_mask=use_mask,
        )
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
####################################################################################
|