# PVD/dataset/shapenet_data_sv.py

import hashlib
import os
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

synset_to_label = {
"02691156": "airplane",
"02773838": "bag",
"02801938": "basket",
"02808440": "bathtub",
"02818832": "bed",
"02828884": "bench",
"02876657": "bottle",
"02880940": "bowl",
"02924116": "bus",
"02933112": "cabinet",
"02747177": "can",
"02942699": "camera",
"02954340": "cap",
"02958343": "car",
"03001627": "chair",
"03046257": "clock",
"03207941": "dishwasher",
"03211117": "monitor",
"04379243": "table",
"04401088": "telephone",
"02946921": "tin_can",
"04460130": "tower",
"04468005": "train",
"03085013": "keyboard",
"03261776": "earphone",
"03325088": "faucet",
"03337140": "file",
"03467517": "guitar",
"03513137": "helmet",
"03593526": "jar",
"03624134": "knife",
"03636649": "lamp",
"03642806": "laptop",
"03691459": "speaker",
"03710193": "mailbox",
"03759954": "microphone",
"03761084": "microwave",
"03790512": "motorcycle",
"03797390": "mug",
"03928116": "piano",
"03938244": "pillow",
"03948459": "pistol",
"03991062": "pot",
"04004475": "printer",
"04074963": "remote_control",
"04090263": "rifle",
"04099429": "rocket",
"04225987": "skateboard",
"04256520": "sofa",
"04330267": "stove",
"04530566": "vessel",
"04554684": "washer",
"02992529": "cellphone",
"02843684": "birdhouse",
"02871439": "bookshelf",
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
# Label to Synset mapping (for ShapeNet core classes)
label_to_synset = {v: k for k, v in synset_to_label.items()}


def _convert_categories(categories):
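    """Map category names (or raw synset ids) to ShapeNet synset ids.

    Known labels are translated and synset ids pass through unchanged,
    e.g. ["chair", "02691156"] -> ["03001627", "02691156"].
    """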
    assert categories is not None, "List of categories cannot be empty!"
    known = set(synset_to_label) | set(label_to_synset)
    if not all(c in known for c in categories):
        warnings.warn(
            "Some or all of the requested categories are not part of "
            "ShapeNetCore. Data loading may fail if these categories are not available."
        )
    synsets = [label_to_synset[c] if c in label_to_synset else c for c in categories]
    return synsets


class ShapeNet_Multiview_Points(Dataset):
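    """ShapeNet dataset pairing complete point clouds with single-view partials.

    For each model, loads the full point cloud from `root_pc` and, for every
    rendered view under `root_views`, back-projects the view's depth map into
    a partial point cloud (cached under `cache`). Items expose sampled train
    points, held-out test points, padded single-view points, and their masks.
    """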
def __init__(
self,
root_pc: str,
root_views: str,
cache: str,
categories: list = ["chair"],
split: str = "val",
npoints=2048,
sv_samples=800,
all_points_mean=None,
all_points_std=None,
get_image=False,
):
self.root = Path(root_views)
self.split = split
self.get_image = get_image
        params = {
            "cat": categories,
            "npoints": npoints,
            "sv_samples": sv_samples,
        }
        params = tuple(sorted(params.items()))
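        # hash the sampling parameters into the cache key so cached partial
        # clouds are regenerated whenever categories/npoints/sv_samples change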
self.cache_dir = Path(cache) / "svpoints/{}/{}".format(
"_".join(categories), hashlib.md5(bytes(repr(params), "utf-8")).hexdigest()
)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.paths = []
self.synset_idxs = []
self.synsets = _convert_categories(categories)
self.labels = [synset_to_label[s] for s in self.synsets]
self.npoints = npoints
self.sv_samples = sv_samples
        self.all_points = []
        self.all_points_sv = []
        # keep names/images aligned with the point arrays across all classes
        self.all_mids = []
        self.imgs = []
        # loop through the requested classes
        for i, syn in enumerate(self.synsets):
            class_target = self.root / syn
            if not class_target.exists():
                raise ValueError(
                    "Class {0} ({1}) was not found at location {2}.".format(
                        syn, self.labels[i], str(class_target)
                    )
                )
sub_path_pc = os.path.join(root_pc, syn, split)
if not os.path.isdir(sub_path_pc):
print("Directory missing : %s" % sub_path_pc)
continue
            mids = []
            for x in os.listdir(sub_path_pc):
                if not x.endswith(".npy"):
                    continue
                mids.append(os.path.join(split, x[: -len(".npy")]))
            for mid in tqdm(mids):
obj_fname = os.path.join(root_pc, syn, mid + ".npy")
cams_pths = list((self.root / syn / mid.split("/")[-1]).glob("*_cam_params.npz"))
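                # require at least 20 camera views per model; skip otherwise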
if len(cams_pths) < 20:
continue
point_cloud = np.load(obj_fname)
sv_points_group = []
img_path_group = []
(self.cache_dir / (mid.split("/")[-1])).mkdir(parents=True, exist_ok=True)
success = True
                for cp in cams_pths:
cp = str(cp)
vp = cp.split("cam_params")[0] + "depth.png"
depth_minmax_pth = cp.split("_cam_params")[0] + ".npy"
cache_pth = str(self.cache_dir / mid.split("/")[-1] / os.path.basename(depth_minmax_pth))
cam_params = np.load(cp)
extr = cam_params["extr"]
intr = cam_params["intr"]
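                    # build a per-view transform from this view's camera
                    # parameters; _render falls back to it on a cache miss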
self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
try:
sv_point_cloud = self._render(cache_pth, vp, depth_minmax_pth)
img_path_group.append(vp)
sv_points_group.append(sv_point_cloud)
                    except Exception as e:
                        print("Skipping {}: {}".format(mid, e))
                        success = False
                        break
                if not success:
                    continue
                self.all_points_sv.append(np.stack(sv_points_group, axis=0))
                self.all_points.append(point_cloud)
                self.imgs.append(img_path_group)
                self.all_mids.append(mid)
self.all_points = np.stack(self.all_points, axis=0)
self.all_points_sv = np.stack(self.all_points_sv, axis=0)
if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats
self.all_points_mean = all_points_mean
self.all_points_std = all_points_std
else: # normalize across the dataset
self.all_points_mean = self.all_points.reshape(-1, 3).mean(axis=0).reshape(1, 1, 3)
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
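        # the source .npy clouds are assumed to hold 15k points (PointFlow
        # convention): the first 10k feed training, the remainder evaluation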
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
        self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std

    def get_pc_stats(self, idx):
        return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)

    def __len__(self):
"""Returns the length of the dataset."""
        return len(self.all_points)

    def __getitem__(self, index):
tr_out = self.train_points[index]
tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
tr_out = tr_out[tr_idxs, :]
gt_points = self.test_points[index][: self.npoints]
m, s = self.get_pc_stats(index)
sv_points = self.all_points_sv[index]
        idxs = np.arange(0, sv_points.shape[-2])[: self.sv_samples]
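        # each view keeps its first sv_samples points and is zero-padded up to
        # npoints; masks flag which rows are real points rather than padding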
data = torch.cat(
[
torch.from_numpy(sv_points[:, idxs]).float(),
torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2]),
],
dim=1,
)
masks = torch.zeros_like(data)
masks[:, : idxs.shape[0]] = 1
res = {
"train_points": torch.from_numpy(tr_out).float(),
"test_points": torch.from_numpy(gt_points).float(),
"sv_points": data,
"masks": masks,
"std": s,
"mean": m,
"idx": index,
"name": self.all_mids[index],
        }

        if self.split != "train" and self.get_image:
img_lst = []
for n in range(self.all_points_sv.shape[1]):
img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
img_lst.append(img)
img = torch.stack(img_lst, dim=0)
res["image"] = img
        return res

    def _render(self, cache_path, depth_pth, depth_minmax_pth):
if os.path.exists(cache_path):
data = np.load(cache_path)
else:
data, depth = self.transform(depth_pth, depth_minmax_pth)
assert data.shape[0] > 600, "Only {} points found".format(data.shape[0])
data = data[np.random.choice(data.shape[0], 600, replace=False)]
np.save(cache_path, data)
return data
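

# `DepthToSingleViewPoints` is referenced above but its definition is not part
# of this excerpt. The sketch below is only an assumption inferred from the
# call sites: the constructor takes `cam_ext`/`cam_int`, and
# __call__(depth_png_path, depth_minmax_npy_path) must return (points, depth),
# with `points` an (N, 3) array back-projected from the depth map through the
# camera. Treat it as a minimal illustration, not the repository's code.
class DepthToSingleViewPoints:
    """Back-project a rendered depth map into a single-view point cloud."""

    def __init__(self, cam_ext, cam_int):
        self.cam_ext = np.asarray(cam_ext).reshape(4, 4)  # camera-to-world (assumed)
        self.cam_int = np.asarray(cam_int).reshape(3, 3)  # pinhole intrinsics (assumed)

    def __call__(self, depth_pth, depth_minmax_pth):
        depth_min, depth_max = np.load(depth_minmax_pth)  # assumed [min, max] pair
        depth_img = plt.imread(depth_pth)
        if depth_img.ndim == 3:  # RGB(A) png: depth replicated per channel
            depth_img = depth_img[..., 0]
        # denormalize [0, 1] png values back to metric depth; treat zero
        # pixels as background (an assumption about the renderer)
        mask = depth_img > 0
        depth = depth_img * (depth_max - depth_min) + depth_min
        v, u = np.nonzero(mask)
        z = depth[v, u]
        # unproject pixels through the intrinsics, then map into world space
        fx, fy = self.cam_int[0, 0], self.cam_int[1, 1]
        cx, cy = self.cam_int[0, 2], self.cam_int[1, 2]
        x = (u - cx) * z / fx
        y = (v - cy) * z / fy
        pts_cam = np.stack([x, y, z, np.ones_like(z)], axis=1)
        pts_world = (self.cam_ext @ pts_cam.T).T[:, :3]
        return pts_world, depth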