PVD/datasets/shapenet_data_sv.py

258 lines
10 KiB
Python
Raw Normal View History

2021-10-19 20:54:46 +00:00
import warnings
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path
import open3d as o3d
import os
import numpy as np
import hashlib
import torch
import matplotlib.pyplot as plt
# Synset id -> human-readable label for the ShapeNet core classes.
# Insertion order is kept identical to the original table.
synset_to_label = {
    '02691156': 'airplane',
    '02773838': 'bag',
    '02801938': 'basket',
    '02808440': 'bathtub',
    '02818832': 'bed',
    '02828884': 'bench',
    '02876657': 'bottle',
    '02880940': 'bowl',
    '02924116': 'bus',
    '02933112': 'cabinet',
    '02747177': 'can',
    '02942699': 'camera',
    '02954340': 'cap',
    '02958343': 'car',
    '03001627': 'chair',
    '03046257': 'clock',
    '03207941': 'dishwasher',
    '03211117': 'monitor',
    '04379243': 'table',
    '04401088': 'telephone',
    '02946921': 'tin_can',
    '04460130': 'tower',
    '04468005': 'train',
    '03085013': 'keyboard',
    '03261776': 'earphone',
    '03325088': 'faucet',
    '03337140': 'file',
    '03467517': 'guitar',
    '03513137': 'helmet',
    '03593526': 'jar',
    '03624134': 'knife',
    '03636649': 'lamp',
    '03642806': 'laptop',
    '03691459': 'speaker',
    '03710193': 'mailbox',
    '03759954': 'microphone',
    '03761084': 'microwave',
    '03790512': 'motorcycle',
    '03797390': 'mug',
    '03928116': 'piano',
    '03938244': 'pillow',
    '03948459': 'pistol',
    '03991062': 'pot',
    '04004475': 'printer',
    '04074963': 'remote_control',
    '04090263': 'rifle',
    '04099429': 'rocket',
    '04225987': 'skateboard',
    '04256520': 'sofa',
    '04330267': 'stove',
    '04530566': 'vessel',
    '04554684': 'washer',
    '02992529': 'cellphone',
    '02843684': 'birdhouse',
    '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}

# Label -> synset id: the inverse of the table above (labels are unique, so
# no entry is lost when the mapping is flipped).
label_to_synset = dict(zip(synset_to_label.values(), synset_to_label.keys()))
def _convert_categories(categories):
    """Translate category labels and/or synset ids into synset ids.

    Args:
        categories: iterable of strings. Each entry is either a label present
            in ``label_to_synset`` (e.g. ``'chair'``) or anything else, which
            is passed through unchanged (normally a synset id).

    Returns:
        list: one entry per input, in input order, with every known label
        replaced by its synset id.

    Raises:
        AssertionError: if ``categories`` is ``None``.

    Warns:
        UserWarning: if at least one entry is neither a known label nor a
        known synset id.
    """
    # Raised explicitly instead of via ``assert`` so the check still runs
    # under ``python -O``; the exception type callers see is unchanged.
    if categories is None:
        raise AssertionError('List of categories cannot be empty!')

    # Materialize once: the entries are walked twice below.
    categories = list(categories)

    # The original tested ``not (<generator>)``, which is always False because
    # a generator object is truthy, so the warning could never be emitted.
    all_known = all(c in synset_to_label or c in label_to_synset
                    for c in categories)
    if not all_known:
        warnings.warn('Some or all of the categories requested are not part of '
                      'ShapeNetCore. Data loading may fail if these categories are not avaliable.')

    return [label_to_synset.get(c, c) for c in categories]
class ShapeNet_Multiview_Points(Dataset):
    """Dataset pairing ShapeNet point clouds with per-view partial point clouds.

    For every requested category the constructor walks
    ``<root_pc>/<synset>/<split>/*.npy`` (complete point clouds) and, for each
    object, the ``*_cam_params.npz`` files under ``<root_views>/<synset>/<id>/``.
    Every view's depth image is back-projected with
    ``DepthToSingleViewPoints`` and the resulting points are cached as ``.npy``
    files below ``cache``.

    Attributes populated by ``__init__`` (all index-aligned along axis 0):
        all_points:    stacked, normalized complete point clouds.
        train_points:  ``all_points[:, :10000]``.
        test_points:   ``all_points[:, 10000:]``.
        all_points_sv: stacked, normalized single-view clouds, one group of
                       views per object.
        all_mids:      ``'<split>/<id>'`` name of each kept object.
        imgs:          per object, the list of depth-image paths of its views.
    """

    def __init__(self, root_pc: str, root_views: str, cache: str, categories: list = ['chair'], split: str = 'val',
                 npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False):
        """Load, cache and normalize every usable object of ``categories``.

        Args:
            root_pc: root directory of the complete point clouds.
            root_views: root directory of the rendered views (depth images,
                camera parameters and depth min/max files).
            cache: directory below which single-view clouds are cached.
            categories: labels and/or synset ids to load. The default list is
                never mutated.
            split: sub-directory of each synset in ``root_pc`` to read.
            npoints: number of points per cloud returned by ``__getitem__``.
            sv_samples: maximum number of single-view points kept per view.
            all_points_mean, all_points_std: normalization statistics to reuse;
                when either is ``None`` they are computed from the loaded data.
            get_image: when true (and ``split != 'train'``) ``__getitem__``
                also returns the view images.

        Raises:
            ValueError: if a category directory is missing below
                ``root_views`` or if no usable object was found at all.
        """
        self.root = Path(root_views)
        self.split = split
        self.get_image = get_image

        # The cache directory name is derived from the loading parameters;
        # this expression is kept exactly as it was so existing caches are
        # still found.
        params = {
            'cat': categories,
            'npoints': npoints,
            'sv_samples': sv_samples,
        }
        params = tuple(sorted(pair for pair in params.items()))
        self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format('_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest())
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.paths = []
        self.synset_idxs = []
        self.synsets = _convert_categories(categories)
        # ``.get`` so that pass-through categories (unknown to the table, which
        # ``_convert_categories`` allows) do not die with a KeyError here.
        self.labels = [synset_to_label.get(s, s) for s in self.synsets]
        self.npoints = npoints
        self.sv_samples = sv_samples

        self.all_points = []
        self.all_points_sv = []
        # Filled only for objects that are actually kept, and across ALL
        # categories, so that index ``k`` of these lists always describes
        # ``self.all_points[k]``.
        self.all_mids = []
        self.imgs = []

        # loops through desired classes
        for syn, label in zip(self.synsets, self.labels):
            class_target = self.root / syn
            if not class_target.exists():
                raise ValueError('Class {0} ({1}) was not found at location {2}.'.format(
                    syn, label, str(class_target)))

            sub_path_pc = os.path.join(root_pc, syn, split)
            if not os.path.isdir(sub_path_pc):
                print("Directory missing : %s" % sub_path_pc)
                continue

            # Sorted: ``os.listdir`` order is arbitrary, and dataset indices
            # should not depend on the filesystem.
            candidate_mids = [os.path.join(split, fname[:-len('.npy')])
                              for fname in sorted(os.listdir(sub_path_pc))
                              if fname.endswith('.npy')]

            for mid in tqdm(candidate_mids):
                loaded = self._load_object(root_pc, syn, mid)
                if loaded is None:
                    continue
                point_cloud, sv_points, img_paths = loaded
                self.all_points.append(point_cloud)
                self.all_points_sv.append(sv_points)
                self.all_mids.append(mid)
                self.imgs.append(img_paths)

        if not self.all_points:
            raise ValueError('No usable objects found for categories {} (split {!r}) under {}.'.format(
                list(categories), split, root_pc))

        self.all_points = np.stack(self.all_points, axis=0)
        self.all_points_sv = np.stack(self.all_points_sv, axis=0)

        if all_points_mean is not None and all_points_std is not None:  # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        else:  # normalize across the dataset
            # Per-axis mean, but a single scalar std over all coordinates.
            self.all_points_mean = self.all_points.reshape(-1, 3).mean(axis=0).reshape(1, 1, 3)
            self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]
        self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std

    def _load_object(self, root_pc, syn, mid):
        """Load one object's full cloud and all of its single-view clouds.

        Args:
            root_pc: root directory of the complete point clouds.
            syn: synset id of the object.
            mid: ``'<split>/<id>'`` name of the object.

        Returns:
            ``(point_cloud, sv_points, img_paths)`` or ``None`` when the object
            has fewer than 20 views or one of its views fails to render.
        """
        model_id = os.path.basename(mid)
        # Sorted so the order of the views is reproducible between runs.
        cams_pths = sorted((self.root / syn / model_id).glob('*_cam_params.npz'))
        if len(cams_pths) < 20:
            return None

        point_cloud = np.load(os.path.join(root_pc, syn, mid + ".npy"))
        (self.cache_dir / model_id).mkdir(parents=True, exist_ok=True)

        sv_points_group = []
        img_path_group = []
        for cam_path in cams_pths:
            cp = str(cam_path)
            # '<dir>/<view>_cam_params.npz' -> '<dir>/<view>'. Strip the known
            # suffix rather than splitting on 'cam_params', which would
            # truncate at any directory that happens to contain that text.
            view_stem = cp[:-len('_cam_params.npz')]
            vp = view_stem + '_depth.png'
            depth_minmax_pth = view_stem + '.npy'
            cache_pth = str(self.cache_dir / model_id / os.path.basename(depth_minmax_pth))

            # Context manager: an .npz handle keeps its file open until closed.
            with np.load(cp) as cam_params:
                extr = cam_params['extr']
                intr = cam_params['intr']
            # ``_render`` reads the transform from ``self.transform``.
            self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)

            try:
                sv_point_cloud = self._render(cache_pth, vp, depth_minmax_pth)
            except Exception as e:
                # Best effort: one unusable view drops the whole object.
                print(e)
                return None
            img_path_group.append(vp)
            sv_points_group.append(sv_point_cloud)

        return point_cloud, np.stack(sv_points_group, axis=0), img_path_group

    def get_pc_stats(self, idx):
        """Return ``(mean, std)`` used for normalization, reshaped to ``(1, 1, -1)``.

        ``idx`` is accepted for interface compatibility and is not used: the
        statistics are shared by the whole dataset.
        """
        return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)

    def __len__(self):
        """Returns the length of the dataset. """
        return len(self.all_points)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Keys:
            train_points: ``npoints`` rows drawn at random (with replacement)
                from ``train_points[index]``, as a float tensor.
            test_points: the first ``npoints`` rows of ``test_points[index]``.
            sv_points: per view, the first ``sv_samples`` single-view points,
                zero-padded along axis 1 up to ``npoints``.
            masks: same shape as ``sv_points``; 1 on real points, 0 on padding.
            std, mean: the arrays returned by ``get_pc_stats``.
            idx: ``index``.
            name: ``all_mids[index]``.
            image: only when ``split != 'train'`` and ``get_image`` is set; the
                view images stacked along a new first axis, first three
                channels only.
        """
        tr_out = self.train_points[index]
        tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
        tr_out = tr_out[tr_idxs, :]
        gt_points = self.test_points[index][:self.npoints]

        m, s = self.get_pc_stats(index)

        sv_points = self.all_points_sv[index]
        # Deterministic prefix of the cached points (no random re-sampling).
        idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]
        num_real = idxs.shape[0]

        padding = torch.zeros(sv_points.shape[0], self.npoints - num_real, sv_points.shape[2])
        data = torch.cat([torch.from_numpy(sv_points[:, idxs]).float(), padding], dim=1)
        masks = torch.zeros_like(data)
        masks[:, :num_real] = 1

        res = {'train_points': torch.from_numpy(tr_out).float(),
               'test_points': torch.from_numpy(gt_points).float(),
               'sv_points': data,
               'masks': masks,
               'std': s, 'mean': m,
               'idx': index,
               'name': self.all_mids[index]
               }

        if self.split != 'train' and self.get_image:
            img_lst = []
            for n in range(self.all_points_sv.shape[1]):
                # Channels-first, keeping only the first three channels.
                img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
                img_lst.append(img)
            res['image'] = torch.stack(img_lst, dim=0)

        return res

    def _render(self, cache_path, depth_pth, depth_minmax_pth):
        """Return the single-view points of one view, using the cache if present.

        On a cache miss the depth image is back-projected with
        ``self.transform``, 600 points are drawn without replacement and the
        result is written to ``cache_path``.

        Raises:
            AssertionError: if the back-projection yields 600 points or fewer.
        """
        if os.path.exists(cache_path):
            return np.load(cache_path)

        data, _ = self.transform(depth_pth, depth_minmax_pth)
        assert data.shape[0] > 600, 'Only {} points found'.format(data.shape[0])
        data = data[np.random.choice(data.shape[0], 600, replace=False)]
        np.save(cache_path, data)
        return data
class DepthToSingleViewPoints(object):
    '''
    Back-project a depth image into a point cloud for one fixed camera.

    The camera is described by an extrinsic matrix (reshaped to 4x4) and an
    intrinsic matrix (reshaped to 3x3), both given at construction time.
    '''
    def __init__(self, cam_ext, cam_int):
        '''
        Args:
            cam_ext: array with 16 elements, stored as a 4x4 matrix.
            cam_int: array with 9 elements, stored as a 3x3 matrix.
        '''
        self.cam_ext = cam_ext.reshape(4, 4)
        self.cam_int = cam_int.reshape(3, 3)

    def __call__(self, depth_pth, depth_minmax_pth):
        '''
        Args:
            depth_pth: path of the depth image; only its first channel is read.
            depth_minmax_pth: path of a ``.npy`` file whose min and max give
                the depth range the image values are rescaled to.

        Returns:
            (pc, depth_img): the points of the back-projected cloud as an
            array, and the rescaled (signed) depth image they came from.
        '''
        depth_minmax = np.load(depth_minmax_pth)
        depth_img = plt.imread(depth_pth)[..., 0]

        # Pixels that are exactly 0 get a negative sign below.
        # NOTE(review): presumably these are background pixels and the
        # negative depth makes Open3D drop them -- confirm.
        mask = np.where(depth_img == 0, -1.0, 1.0)

        # The image stores inverted depth in [0, 1]; map it back onto the
        # [min, max] range read from the side file.
        depth_img = 1 - depth_img
        d_min = np.min(depth_minmax)
        d_max = np.max(depth_minmax)
        depth_img = (depth_img * (d_max - d_min) + d_min) * mask

        # PinholeCameraIntrinsic takes (width, height, fx, fy, cx, cy). An
        # image array is indexed (row, col) = (height, width); the two used to
        # be passed swapped, which only went unnoticed on square images.
        height, width = depth_img.shape[0], depth_img.shape[1]
        intr = o3d.camera.PinholeCameraIntrinsic(width, height,
                                                 self.cam_int[0, 0], self.cam_int[1, 1], self.cam_int[0, 2],
                                                 self.cam_int[1, 2])

        depth_im = o3d.geometry.Image(depth_img.astype(np.float32, copy=False))
        pcd = o3d.geometry.PointCloud.create_from_depth_image(depth_im, intr, self.cam_ext, depth_scale=1.)
        pc = np.asarray(pcd.points)

        return pc, depth_img

    def __repr__(self):
        # The previous implementation read attributes this class never sets
        # (radius, resolution, elev, azim, img_size) and always raised
        # AttributeError.
        return '{}(cam_ext={}, cam_int={})'.format(
            type(self).__name__, self.cam_ext.tolist(), self.cam_int.tolist())