import hashlib
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

# ShapeNetCore synset id -> human-readable label.
synset_to_label = {
    "02691156": "airplane", "02773838": "bag", "02801938": "basket",
    "02808440": "bathtub", "02818832": "bed", "02828884": "bench",
    "02876657": "bottle", "02880940": "bowl", "02924116": "bus",
    "02933112": "cabinet", "02747177": "can", "02942699": "camera",
    "02954340": "cap", "02958343": "car", "03001627": "chair",
    "03046257": "clock", "03207941": "dishwasher", "03211117": "monitor",
    "04379243": "table", "04401088": "telephone", "02946921": "tin_can",
    "04460130": "tower", "04468005": "train", "03085013": "keyboard",
    "03261776": "earphone", "03325088": "faucet", "03337140": "file",
    "03467517": "guitar", "03513137": "helmet", "03593526": "jar",
    "03624134": "knife", "03636649": "lamp", "03642806": "laptop",
    "03691459": "speaker", "03710193": "mailbox", "03759954": "microphone",
    "03761084": "microwave", "03790512": "motorcycle", "03797390": "mug",
    "03928116": "piano", "03938244": "pillow", "03948459": "pistol",
    "03991062": "pot", "04004475": "printer", "04074963": "remote_control",
    "04090263": "rifle", "04099429": "rocket", "04225987": "skateboard",
    "04256520": "sofa", "04330267": "stove", "04530566": "vessel",
    "04554684": "washer", "02992529": "cellphone", "02843684": "birdhouse",
    "02871439": "bookshelf",
    # "02858304": "boat",  # no boat in our dataset, merged into vessels
    # "02834778": "bicycle",  # not in our taxonomy
}

# Label to synset mapping (for ShapeNetCore classes).
label_to_synset = {v: k for k, v in synset_to_label.items()}


def _convert_categories(categories):
    """Map human-readable labels to synset ids; synset ids pass through unchanged."""
    assert categories is not None, "List of categories cannot be empty!"
    if not all(c in synset_to_label or c in label_to_synset for c in categories):
        warnings.warn(
            "Some or all of the categories requested are not part of "
            "ShapeNetCore. Data loading may fail if these categories are not available."
        )
    synsets = [label_to_synset[c] if c in label_to_synset else c for c in categories]
    return synsets
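# A quick illustrative example of _convert_categories (values taken from the tables
# above; the call itself is just a sketch). Labels such as "chair" are resolved
# through label_to_synset, while strings that are already synset ids pass through
# unchanged:
#
#     >>> _convert_categories(["chair", "02691156"])
#     ['03001627', '02691156']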
class ShapeNet_Multiview_Points(Dataset):
    def __init__(
        self,
        root_pc: str,
        root_views: str,
        cache: str,
        categories: list = ["chair"],
        split: str = "val",
        npoints=2048,
        sv_samples=800,
        all_points_mean=None,
        all_points_std=None,
        get_image=False,
    ):
        self.root = Path(root_views)
        self.split = split
        self.get_image = get_image

        # Hash the sampling parameters so each configuration gets its own cache directory.
        params = {
            "cat": categories,
            "npoints": npoints,
            "sv_samples": sv_samples,
        }
        params = tuple(sorted(pair for pair in params.items()))
        self.cache_dir = Path(cache) / "svpoints/{}/{}".format(
            "_".join(categories), hashlib.md5(bytes(repr(params), "utf-8")).hexdigest()
        )
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.paths = []
        self.synset_idxs = []
        self.synsets = _convert_categories(categories)
        self.labels = [synset_to_label[s] for s in self.synsets]
        self.npoints = npoints
        self.sv_samples = sv_samples

        self.all_points = []
        self.all_points_sv = []
        self.all_mids = []
        self.imgs = []

        # Loop through the desired classes.
        for i in range(len(self.synsets)):
            syn = self.synsets[i]
            class_target = self.root / syn
            if not class_target.exists():
                raise ValueError(
                    "Class {0} ({1}) was not found at location {2}.".format(
                        syn, self.labels[i], str(class_target)
                    )
                )

            sub_path_pc = os.path.join(root_pc, syn, split)
            if not os.path.isdir(sub_path_pc):
                print("Directory missing: %s" % sub_path_pc)
                continue

            mids = []
            for x in os.listdir(sub_path_pc):
                if not x.endswith(".npy"):
                    continue
                mids.append(os.path.join(split, x[: -len(".npy")]))

            for mid in tqdm(mids):
                obj_fname = os.path.join(root_pc, syn, mid + ".npy")
                cams_pths = list((self.root / syn / mid.split("/")[-1]).glob("*_cam_params.npz"))
                # Skip models without a full set of rendered views.
                if len(cams_pths) < 20:
                    continue
                point_cloud = np.load(obj_fname)
                sv_points_group = []
                img_path_group = []
                (self.cache_dir / mid.split("/")[-1]).mkdir(parents=True, exist_ok=True)
                success = True
                for cp in cams_pths:
                    cp = str(cp)
                    vp = cp.split("cam_params")[0] + "depth.png"
                    depth_minmax_pth = cp.split("_cam_params")[0] + ".npy"
                    cache_pth = str(self.cache_dir / mid.split("/")[-1] / os.path.basename(depth_minmax_pth))

                    cam_params = np.load(cp)
                    extr = cam_params["extr"]
                    intr = cam_params["intr"]
                    # DepthToSingleViewPoints (defined elsewhere in this module)
                    # back-projects a depth map into a single-view point cloud.
                    self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
                    try:
                        sv_point_cloud = self._render(cache_pth, vp, depth_minmax_pth)
                        img_path_group.append(vp)
                        sv_points_group.append(sv_point_cloud)
                    except Exception as e:
                        print(e)
                        success = False
                        break
                if not success:
                    continue
                self.all_points_sv.append(np.stack(sv_points_group, axis=0))  # (num_views, n, 3)
                self.all_points.append(point_cloud)
                # Record ids and image paths only for successfully loaded models,
                # keeping them index-aligned with all_points.
                self.all_mids.append(mid)
                self.imgs.append(img_path_group)

        self.all_points = np.stack(self.all_points, axis=0)  # (N, num_points, 3)
        self.all_points_sv = np.stack(self.all_points_sv, axis=0)  # (N, num_views, n, 3)

        if all_points_mean is not None and all_points_std is not None:
            # Use normalization stats from a previously loaded dataset.
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        else:
            # Normalize across the dataset: per-axis mean, single scalar std.
            self.all_points_mean = self.all_points.reshape(-1, 3).mean(axis=0).reshape(1, 1, 3)
            self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]
        self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std

    def get_pc_stats(self, idx):
        return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)
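    # A toy sketch of the normalization done in __init__ (hypothetical shapes, not
    # dataset values): one xyz mean over every point of every model and one scalar
    # std over all coordinates, so the relative scale between models is preserved.
    #
    #     pts = np.random.randn(4, 15000, 3)                       # (models, points, xyz)
    #     mean = pts.reshape(-1, 3).mean(axis=0).reshape(1, 1, 3)  # per-axis mean
    #     std = pts.reshape(-1).std(axis=0).reshape(1, 1, 1)       # scalar std
    #     normed = (pts - mean) / std                              # broadcasts over models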
    def __len__(self):
        """Returns the length of the dataset."""
        return len(self.all_points)

    def __getitem__(self, index):
        # Randomly resample npoints from the training half of the cloud.
        tr_out = self.train_points[index]
        tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
        tr_out = tr_out[tr_idxs, :]

        gt_points = self.test_points[index][: self.npoints]
        m, s = self.get_pc_stats(index)

        sv_points = self.all_points_sv[index]
        # Take the first sv_samples points of each view.
        idxs = np.arange(0, sv_points.shape[-2])[: self.sv_samples]
        # Alternative: np.random.choice(sv_points.shape[0], 500, replace=False)

        # Zero-pad each view up to npoints so every sample has a fixed shape;
        # masks mark real points (1) versus padding (0).
        data = torch.cat(
            [
                torch.from_numpy(sv_points[:, idxs]).float(),
                torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2]),
            ],
            dim=1,
        )
        masks = torch.zeros_like(data)
        masks[:, : idxs.shape[0]] = 1

        res = {
            "train_points": torch.from_numpy(tr_out).float(),
            "test_points": torch.from_numpy(gt_points).float(),
            "sv_points": data,
            "masks": masks,
            "std": s,
            "mean": m,
            "idx": index,
            "name": self.all_mids[index],
        }

        if self.split != "train" and self.get_image:
            img_lst = []
            for n in range(self.all_points_sv.shape[1]):
                # Load RGB channels and move them to (C, H, W).
                img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
                img_lst.append(img)
            res["image"] = torch.stack(img_lst, dim=0)
        return res

    def _render(self, cache_path, depth_pth, depth_minmax_pth):
        # Reuse cached single-view points when available; otherwise back-project
        # the depth map and cache a fixed-size subsample.
        if os.path.exists(cache_path):
            data = np.load(cache_path)
        else:
            data, depth = self.transform(depth_pth, depth_minmax_pth)
            assert data.shape[0] > 600, "Only {} points found".format(data.shape[0])
            data = data[np.random.choice(data.shape[0], 600, replace=False)]
            np.save(cache_path, data)
        return data
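# A minimal usage sketch, assuming the on-disk layout this loader expects:
# <root_pc>/<synset>/<split>/<model_id>.npy for complete point clouds, and
# <root_views>/<synset>/<model_id>/*_cam_params.npz (with matching *depth.png and
# depth-minmax .npy files) for rendered views. All paths below are hypothetical
# placeholders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = ShapeNet_Multiview_Points(
        root_pc="/data/shapenet_pc",        # hypothetical path
        root_views="/data/shapenet_views",  # hypothetical path
        cache="/tmp/sv_cache",              # hypothetical path
        categories=["chair"],
        split="val",
        npoints=2048,
        sv_samples=800,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(loader))
    # train_points: (B, npoints, 3); sv_points: (B, num_views, npoints, 3)
    print(batch["train_points"].shape, batch["sv_points"].shape)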