PVD/dataset/shapenet_data_sv.py
2023-04-11 13:50:00 +02:00

270 lines
9.1 KiB
Python

import hashlib
import os
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
# ShapeNetCore synset id -> human-readable category label.
synset_to_label = {
    "02691156": "airplane",
    "02773838": "bag",
    "02801938": "basket",
    "02808440": "bathtub",
    "02818832": "bed",
    "02828884": "bench",
    "02876657": "bottle",
    "02880940": "bowl",
    "02924116": "bus",
    "02933112": "cabinet",
    "02747177": "can",
    "02942699": "camera",
    "02954340": "cap",
    "02958343": "car",
    "03001627": "chair",
    "03046257": "clock",
    "03207941": "dishwasher",
    "03211117": "monitor",
    "04379243": "table",
    "04401088": "telephone",
    "02946921": "tin_can",
    "04460130": "tower",
    "04468005": "train",
    "03085013": "keyboard",
    "03261776": "earphone",
    "03325088": "faucet",
    "03337140": "file",
    "03467517": "guitar",
    "03513137": "helmet",
    "03593526": "jar",
    "03624134": "knife",
    "03636649": "lamp",
    "03642806": "laptop",
    "03691459": "speaker",
    "03710193": "mailbox",
    "03759954": "microphone",
    "03761084": "microwave",
    "03790512": "motorcycle",
    "03797390": "mug",
    "03928116": "piano",
    "03938244": "pillow",
    "03948459": "pistol",
    "03991062": "pot",
    "04004475": "printer",
    "04074963": "remote_control",
    "04090263": "rifle",
    "04099429": "rocket",
    "04225987": "skateboard",
    "04256520": "sofa",
    "04330267": "stove",
    "04530566": "vessel",
    "04554684": "washer",
    "02992529": "cellphone",
    "02843684": "birdhouse",
    "02871439": "bookshelf",
    # Deliberately left out:
    #   '02858304': 'boat'    -- no boat in our dataset, merged into vessels
    #   '02834778': 'bicycle' -- not in our taxonomy
}

# Reverse lookup (category label -> synset id) for the ShapeNet core classes above.
label_to_synset = {label: synset for synset, label in synset_to_label.items()}
def _convert_categories(categories):
    """Translate category labels and/or synset ids into synset ids.

    Each entry that is a known label (a key of ``label_to_synset``) is replaced
    by its synset id; every other entry is passed through unchanged, so synset
    ids and unknown names survive as-is.

    Args:
        categories: Iterable of category labels (e.g. ``"chair"``) and/or
            synset ids (e.g. ``"03001627"``).

    Returns:
        A list with one synset id (or untouched unknown entry) per input entry,
        in input order.

    Raises:
        AssertionError: If ``categories`` is ``None``.

    Warns:
        UserWarning: If at least one entry is neither a known label nor a
            known synset id.
    """
    assert categories is not None, "List of categories cannot be empty!"
    # The previous check was `not (<generator>)`, which is always False because
    # a generator object is truthy, so the warning could never be emitted.
    known = synset_to_label.keys() | label_to_synset.keys()
    if not all(c in known for c in categories):
        warnings.warn(
            "Some or all of the categories requested are not part of "
            "ShapeNetCore. Data loading may fail if these categories are not available."
        )
    return [label_to_synset.get(c, c) for c in categories]
class ShapeNet_Multiview_Points(Dataset):
    """ShapeNet point clouds paired with per-view (single-view) point clouds.

    For every requested category the constructor looks for complete point
    clouds at ``<root_pc>/<synset>/<split>/<model_id>.npy`` and for camera files
    at ``<root_views>/<synset>/<model_id>/*_cam_params.npz``. Models with fewer
    than 20 camera files, or for which any view fails to convert, are skipped.
    Each view is converted to a point cloud by ``DepthToSingleViewPoints``,
    subsampled to 600 points and cached as ``.npy`` under ``cache``.

    After loading, ``all_points``, ``all_points_sv``, ``imgs`` and ``all_mids``
    are index-aligned: entry ``i`` of each describes the same model.
    """

    def __init__(
        self,
        root_pc: str,
        root_views: str,
        cache: str,
        categories: list = None,
        split: str = "val",
        npoints=2048,
        sv_samples=800,
        all_points_mean=None,
        all_points_std=None,
        get_image=False,
    ):
        """Load, cache and normalize the dataset.

        Args:
            root_pc: Root directory of the complete point clouds.
            root_views: Root directory of the rendered views / camera files.
            cache: Root directory for cached single-view point clouds.
            categories: Category labels and/or synset ids; defaults to
                ``["chair"]`` when omitted.
            split: Sub-directory of each synset folder in ``root_pc`` to read.
            npoints: Number of points per sample returned by ``__getitem__``.
            sv_samples: Maximum number of single-view points kept per view.
            all_points_mean: Precomputed mean; used only together with
                ``all_points_std``.
            all_points_std: Precomputed std; used only together with
                ``all_points_mean``.
            get_image: If true (and ``split != "train"``), ``__getitem__`` also
                returns the depth images.

        Raises:
            ValueError: If a category directory is missing under ``root_views``
                or if no model could be loaded at all.
        """
        if categories is None:
            # Fresh list per call instead of a shared mutable default argument.
            categories = ["chair"]

        self.root = Path(root_views)
        self.split = split
        self.get_image = get_image

        # The cache directory name embeds a hash of the sampling parameters so
        # that different configurations do not share cached files.
        params = {
            "cat": categories,
            "npoints": npoints,
            "sv_samples": sv_samples,
        }
        params = tuple(sorted(params.items()))
        self.cache_dir = Path(cache) / "svpoints/{}/{}".format(
            "_".join(categories), hashlib.md5(bytes(repr(params), "utf-8")).hexdigest()
        )
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.paths = []
        self.synset_idxs = []
        self.synsets = _convert_categories(categories)
        self.labels = [synset_to_label[s] for s in self.synsets]
        self.npoints = npoints
        self.sv_samples = sv_samples

        # Index-aligned per-model containers. They are created once, outside the
        # category loop, and only appended to when a model loads successfully;
        # otherwise `all_mids[index]` / `imgs[index]` would not correspond to
        # `all_points[index]`.
        self.all_points = []
        self.all_points_sv = []
        self.all_mids = []
        self.imgs = []

        for syn, label in zip(self.synsets, self.labels):
            class_target = self.root / syn
            if not class_target.exists():
                raise ValueError(
                    "Class {0} ({1}) was not found at location {2}.".format(syn, label, str(class_target))
                )

            sub_path_pc = os.path.join(root_pc, syn, split)
            if not os.path.isdir(sub_path_pc):
                print("Directory missing : %s" % sub_path_pc)
                continue

            model_ids = [x[: -len(".npy")] for x in os.listdir(sub_path_pc) if x.endswith(".npy")]

            for model_id in tqdm(model_ids):
                mid = os.path.join(split, model_id)
                obj_fname = os.path.join(root_pc, syn, mid + ".npy")

                cams_pths = list((class_target / model_id).glob("*_cam_params.npz"))
                if len(cams_pths) < 20:
                    # Not enough rendered views for this model.
                    continue

                point_cloud = np.load(obj_fname)
                (self.cache_dir / model_id).mkdir(parents=True, exist_ok=True)

                sv_points_group = []
                img_path_group = []
                success = True
                for cp in cams_pths:
                    cp = str(cp)
                    vp = cp.split("cam_params")[0] + "depth.png"
                    depth_minmax_pth = cp.split("_cam_params")[0] + ".npy"
                    cache_pth = str(self.cache_dir / model_id / os.path.basename(depth_minmax_pth))

                    # NpzFile keeps its file open until closed; close it
                    # explicitly instead of leaking one handle per view.
                    with np.load(cp) as cam_params:
                        extr = cam_params["extr"]
                        intr = cam_params["intr"]

                    # `_render` reads the converter from `self.transform`.
                    self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
                    try:
                        sv_point_cloud = self._render(cache_pth, vp, depth_minmax_pth)
                    except Exception as e:
                        # Best effort: one bad view drops the whole model.
                        print(e)
                        success = False
                        break
                    img_path_group.append(vp)
                    sv_points_group.append(sv_point_cloud)

                if not success:
                    continue

                self.all_points_sv.append(np.stack(sv_points_group, axis=0))
                self.all_points.append(point_cloud)
                self.imgs.append(img_path_group)
                self.all_mids.append(mid)

        if not self.all_points:
            raise ValueError(
                "No usable models were found for categories {} in split '{}'.".format(categories, split)
            )

        self.all_points = np.stack(self.all_points, axis=0)
        self.all_points_sv = np.stack(self.all_points_sv, axis=0)

        if all_points_mean is not None and all_points_std is not None:  # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        else:  # normalize across the dataset
            # Per-axis mean, but a single scalar std over all coordinates.
            self.all_points_mean = self.all_points.reshape(-1, 3).mean(axis=0).reshape(1, 1, 3)
            self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        # The first 10000 points of each cloud feed "train_points", the rest "test_points".
        # NOTE(review): presumably each stored cloud has at least 10000 + npoints points -- confirm.
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]
        self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std

    def get_pc_stats(self, idx):
        """Return the dataset-wide ``(mean, std)``, each reshaped to ``(1, 1, -1)``.

        ``idx`` is accepted for interface compatibility but not used.
        """
        return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)

    def __len__(self):
        """Returns the length of the dataset."""
        return len(self.all_points)

    def __getitem__(self, index):
        """Build the sample dictionary for model ``index``.

        Returns:
            A dict with keys:
              * ``"train_points"``: ``npoints`` points drawn at random (with
                replacement) from the model's train points.
              * ``"test_points"``: the first ``npoints`` test points.
              * ``"sv_points"``: per view, the first ``sv_samples`` single-view
                points followed by zero padding up to ``npoints`` points.
              * ``"masks"``: same shape as ``"sv_points"``; 1 on real points,
                0 on padding.
              * ``"std"``, ``"mean"``: normalization statistics.
              * ``"idx"``: ``index``; ``"name"``: the model's ``<split>/<id>``.
              * ``"image"`` (only if ``split != "train"`` and ``get_image``):
                the stacked depth images of all views, first 3 channels.
        """
        tr_out = self.train_points[index]
        tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
        tr_out = tr_out[tr_idxs, :]
        gt_points = self.test_points[index][: self.npoints]

        m, s = self.get_pc_stats(index)

        sv_points = self.all_points_sv[index]
        # Deterministic prefix rather than a random subset of the view points.
        idxs = np.arange(0, sv_points.shape[-2])[: self.sv_samples]
        data = torch.cat(
            [
                torch.from_numpy(sv_points[:, idxs]).float(),
                torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2]),
            ],
            dim=1,
        )
        masks = torch.zeros_like(data)
        masks[:, : idxs.shape[0]] = 1

        res = {
            "train_points": torch.from_numpy(tr_out).float(),
            "test_points": torch.from_numpy(gt_points).float(),
            "sv_points": data,
            "masks": masks,
            "std": s,
            "mean": m,
            "idx": index,
            "name": self.all_mids[index],
        }

        if self.split != "train" and self.get_image:
            img_lst = []
            for n in range(self.all_points_sv.shape[1]):
                img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
                img_lst.append(img)
            res["image"] = torch.stack(img_lst, dim=0)

        return res

    def _render(self, cache_path, depth_pth, depth_minmax_pth):
        """Return the 600-point single-view cloud for one view, using the cache.

        If ``cache_path`` exists it is loaded and returned as-is. Otherwise
        ``self.transform`` is called with the two paths, its first return value
        is randomly subsampled (without replacement) to 600 points, saved to
        ``cache_path`` and returned.

        Raises:
            ValueError: If the transform yields 600 points or fewer.
        """
        if os.path.exists(cache_path):
            return np.load(cache_path)

        data, _ = self.transform(depth_pth, depth_minmax_pth)
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if not data.shape[0] > 600:
            raise ValueError("Only {} points found".format(data.shape[0]))
        data = data[np.random.choice(data.shape[0], 600, replace=False)]
        np.save(cache_path, data)
        return data