# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
import os
import open3d as o3d
import time
import torch
import numpy as np
from loguru import logger
from torch.utils.data import Dataset
from torch.utils import data
import random
import tqdm
from datasets.data_path import get_path
from PIL import Image

# When truthy: keep only the first 40 shapes and disable random subsampling
# (debug switch for overfitting experiments).
OVERFIT = 0

# taken from https://github.com/optas/latent_3d_points/blob/
# 8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
# NOTE: insertion order matters -- categories='all' enumerates this dict's
# values in order, which determines each shape's `cate_idx`.
synsetid_to_cate = {
    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
    '04554684': 'washer', '02992529': 'cellphone', '02843684': 'birdhouse',
    '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}


def _bbox_center_and_half_extent(points, input_dim):
    """Per-shape bounding-box normalization statistics.

    Args:
        points: array of shape (B, N, input_dim).
        input_dim: size of the last axis of ``points``.

    Returns:
        (center, half_extent): ``center`` has shape (B, 1, input_dim) and is the
        midpoint of each shape's axis-aligned bounding box; ``half_extent`` has
        shape (B, 1, 1) and is half of the box's largest side, so dividing by it
        maps the longest side to [-1, 1].
    """
    B = points.shape[0]
    pmax = np.amax(points, axis=1).reshape(B, 1, input_dim)
    pmin = np.amin(points, axis=1).reshape(B, 1, input_dim)
    center = (pmax + pmin) / 2
    half_extent = np.amax(pmax - pmin, axis=-1).reshape(B, 1, 1) / 2
    return center, half_extent


class ShapeNet15kPointClouds(Dataset):
    """ShapeNet point clouds stored as ``<root>/<synset_id>/<split>/<name>.npy``.

    Every file is loaded eagerly, the shapes are shuffled with a fixed seed,
    normalized according to the selected mode, and truncated to at most 10k
    points per shape for training.
    """

    def __init__(self,
                 categories=None,
                 tr_sample_size=10000,
                 te_sample_size=10000,
                 split='train',
                 scale=1.,
                 normalize_per_shape=False,
                 normalize_shape_box=False,
                 random_subsample=False,
                 sample_with_replacement=1,
                 normalize_std_per_axis=False,
                 normalize_global=False,
                 recenter_per_shape=False,
                 all_points_mean=None,
                 all_points_std=None,
                 input_dim=3,
                 clip_forge_enable=0,
                 clip_model=None):
        """Load and normalize the point clouds.

        Args:
            categories: category name, list of names, or a list containing
                'all'. Defaults to ``['airplane']``.
            tr_sample_size: points returned per item (capped at 10000).
            te_sample_size: stored only (capped at 5000); unused here.
            split: one of 'train', 'test', 'val'.
            scale: must be 1 (non-unit scale is deprecated).
            normalize_per_shape: per-shape mean/std normalization.
            normalize_shape_box: per-shape bounding-box normalization.
            random_subsample: randomly choose points in ``__getitem__``.
            sample_with_replacement: when subsampling, sample with replacement.
            normalize_std_per_axis: std per xyz axis instead of a scalar std.
            normalize_global: dataset-wide mean/std normalization.
            recenter_per_shape: per-shape bounding-box recentering.
            all_points_mean, all_points_std: precomputed stats (e.g. from the
                train split) used when no per-shape mode takes precedence.
            input_dim: point dimensionality.
            clip_forge_enable: also load rendered images for CLIP.
            clip_model: model name passed to ``clip.load``.

        Raises:
            ValueError: a category/split directory is missing or holds no .npy.
            NotImplementedError: no normalization mode applies.
        """
        if categories is None:
            categories = ['airplane']
        self.clip_forge_enable = clip_forge_enable
        if clip_forge_enable:
            import clip
            _, self.clip_preprocess = clip.load(clip_model)
            self.img_path = []
            img_path = get_path('clip_forge_image')

        self.normalize_shape_box = normalize_shape_box
        root_dir = get_path('pointflow')
        self.root_dir = root_dir
        logger.info('[DATA] cat: {}, split: {}, full path: {}; norm global={}, norm-box={}',
                    categories, split, self.root_dir, normalize_global, normalize_shape_box)

        self.split = split
        assert self.split in ['train', 'test', 'val']
        if isinstance(categories, str):
            categories = [categories]
        self.cates = categories
        if 'all' in categories:
            self.synset_ids = list(cate_to_synsetid.values())
        else:
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
        self.subdirs = self.synset_ids

        self.gravity_axis = 1
        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size
        self.scale = scale
        self.random_subsample = random_subsample
        self.sample_with_replacement = sample_with_replacement
        self.input_dim = input_dim

        # ---- load every .npy file of the requested categories/split ----
        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        tic = time.time()
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is the synset id.
            sub_path = os.path.join(root_dir, subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s " % (sub_path))
                raise ValueError('check the data path')
            file_names = os.listdir(sub_path)
            # NOTE: each [mid] keeps the split prefix, i.e. "train/", "val/" or "test/".
            all_mids = sorted(
                os.path.join(self.split, x[:-len('.npy')])
                for x in file_names if x.endswith('.npy'))
            logger.info('[DATA] number of file [{}] under: {} ',
                        len(file_names), sub_path)
            for mid in all_mids:
                if self.clip_forge_enable:
                    render_img_path = os.path.join(
                        img_path, subd, mid.split('/')[-1], 'img_choy2016')
                    self.img_path.append(render_img_path)
                    assert os.path.exists(render_img_path), \
                        f'render img path not find: {render_img_path}'
                obj_fname = os.path.join(root_dir, subd, mid + ".npy")
                point_cloud = np.load(obj_fname)  # (15k, 3)
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))
        logger.info('[DATA] Load data time: {:.1f}s | dir: {} | '
                    'sample_with_replacement: {}; num points: {}',
                    time.time() - tic, self.subdirs,
                    self.sample_with_replacement, len(self.all_points))
        if not self.all_points:
            # np.concatenate([]) would otherwise fail with an opaque message.
            raise ValueError(
                f'no .npy point clouds found for {self.subdirs} (split={self.split}) under {root_dir}')

        # Shuffle the index deterministically (fixed seed, depends only on the count).
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
        if self.clip_forge_enable:
            self.img_path = [self.img_path[i] for i in self.shuffle_idx]

        # ---- normalization ----
        self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
        self.normalize_per_shape = normalize_per_shape
        self.normalize_std_per_axis = normalize_std_per_axis
        self.recenter_per_shape = recenter_per_shape
        # Branch order is significant: per-shape modes win over loaded stats.
        if self.normalize_shape_box:
            self.all_points_mean, self.all_points_std = \
                _bbox_center_and_half_extent(self.all_points, input_dim)
        elif self.normalize_per_shape:
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
            logger.info('all_points shape: {}. mean over axis=1', self.all_points.shape)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(
                    B, N, -1).std(axis=1).reshape(B, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(
                    B, -1).std(axis=1).reshape(B, 1, 1)
        elif all_points_mean is not None and all_points_std is not None \
                and not self.recenter_per_shape:
            # Reuse externally supplied (e.g. training-set) statistics.
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        elif self.recenter_per_shape:
            # TODO: bounding box scale at the large dim and center
            self.all_points_mean, self.all_points_std = \
                _bbox_center_and_half_extent(self.all_points, input_dim)
        elif normalize_global:
            flat = self.all_points.reshape(-1, input_dim)
            self.all_points_mean = flat.mean(axis=0).reshape(1, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = flat.std(axis=0).reshape(1, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(-1).std(
                    axis=0).reshape(1, 1, 1)
            logger.info('[DATA] normalize_global: mean={}, std={}',
                        self.all_points_mean.reshape(-1),
                        self.all_points_std.reshape(-1))
        else:
            raise NotImplementedError('No Normalization')
        self.all_points = (self.all_points - self.all_points_mean) / \
            self.all_points_std
        logger.info('[DATA] shape={}, all_points_mean:={}, std={}, max={:.3f}, min={:.3f}; num-pts={}',
                    self.all_points.shape,
                    self.all_points_mean.shape, self.all_points_std.shape,
                    self.all_points.max(), self.all_points.min(), tr_sample_size)

        if OVERFIT:
            self.all_points = self.all_points[:40]

        # Keep at most the first 10k of the (typically 15k) points per shape.
        self.train_points = self.all_points[:, :min(10000, self.all_points.shape[1])]
        self.tr_sample_size = min(10000, tr_sample_size)
        self.te_sample_size = min(5000, te_sample_size)
        assert self.scale == 1, "Scale (!= 1) is deprecated"

        # Default display axis order
        self.display_axis_order = [0, 1, 2]

    def get_pc_stats(self, idx):
        """Return the (mean, std) used to normalize shape ``idx``, each 2-D."""
        if self.recenter_per_shape or self.normalize_per_shape or self.normalize_shape_box:
            # Per-shape statistics: one row per shape.
            m = self.all_points_mean[idx].reshape(1, self.input_dim)
            s = self.all_points_std[idx].reshape(1, -1)
            return m, s
        # Dataset-wide (global or externally supplied) statistics.
        return self.all_points_mean.reshape(1, -1), \
            self.all_points_std.reshape(1, -1)

    def renormalize(self, mean, std):
        """Undo the current normalization and re-apply it with ``mean``/``std``."""
        self.all_points = self.all_points * self.all_points_std + \
            self.all_points_mean
        self.all_points_mean = mean
        self.all_points_std = std
        self.all_points = (self.all_points - self.all_points_mean) / \
            self.all_points_std
        self.train_points = self.all_points[:, :min(10000, self.all_points.shape[1])]

    def __len__(self):
        return len(self.train_points)

    def _load_clip_images(self, idx):
        """Load 5 randomly chosen (with replacement) renderings of shape ``idx``.

        Returns a tensor stacked along dim 0 from ``self.clip_preprocess`` outputs.
        """
        img_dir = self.img_path[idx]
        names = os.listdir(img_dir)
        img_list = [os.path.join(img_dir, p) for p in names if 'jpg' in p or 'png' in p]
        assert len(img_list) > 0, f'get empty list at {img_dir}: {names}'
        img_idx = np.random.choice(len(img_list), 5)
        images = []
        for o in img_idx:
            with Image.open(img_list[o]) as img:
                images.append(self.clip_preprocess(img.convert('RGB')))
        return torch.stack(images, dim=0)  # B,3,H,W

    def __getitem__(self, idx):
        """Return a dict with the (optionally subsampled) points of shape ``idx``."""
        tr_out = self.train_points[idx]
        if self.random_subsample and self.sample_with_replacement:
            tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
        elif self.random_subsample:
            tr_idxs = np.random.permutation(
                np.arange(tr_out.shape[0]))[:self.tr_sample_size]
        else:
            tr_idxs = np.arange(self.tr_sample_size)
        tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
        m, s = self.get_pc_stats(idx)
        sid, mid = self.all_cate_mids[idx]
        output = {
            'idx': idx,
            'select_idx': tr_idxs,
            'tr_points': tr_out,
            'input_pts': tr_out,
            'mean': m,
            'std': s,
            'cate_idx': self.cate_idx_lst[idx],
            'sid': sid,
            'mid': mid,
            'display_axis_order': self.display_axis_order,
        }
        if self.clip_forge_enable:
            output['tr_img'] = self._load_clip_images(idx)
        return output


def init_np_seed(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy from the worker's torch seed."""
    seed = torch.initial_seed()
    np.random.seed(seed % 4294967296)


def get_datasets(cfg, args):
    """Build the (train, eval) datasets.

    Args:
        cfg: the ``config.data`` sub-part.
        args: namespace; optional ``eval_split`` (default "val").

    Returns:
        (tr_dataset, te_dataset). The eval set receives the train-set
        normalization statistics.
    """
    random_subsample = 0 if OVERFIT else cfg.random_subsample
    logger.info(f'get_datasets: tr_sample_size={cfg.tr_max_sample_points}, '
                f' te_sample_size={cfg.te_max_sample_points}; '
                f' random_subsample={random_subsample}'
                f' normalize_global={cfg.normalize_global}'
                f' normalize_std_per_axis={cfg.normalize_std_per_axis}'
                f' normalize_per_shape={cfg.normalize_per_shape}'
                f' recenter_per_shape={cfg.recenter_per_shape}')
    tr_dataset = ShapeNet15kPointClouds(
        categories=cfg.cates, split='train',
        tr_sample_size=cfg.tr_max_sample_points,
        te_sample_size=cfg.te_max_sample_points,
        sample_with_replacement=cfg.sample_with_replacement,
        scale=cfg.dataset_scale,
        normalize_shape_box=cfg.normalize_shape_box,
        normalize_per_shape=cfg.normalize_per_shape,
        normalize_std_per_axis=cfg.normalize_std_per_axis,
        normalize_global=cfg.normalize_global,
        recenter_per_shape=cfg.recenter_per_shape,
        random_subsample=random_subsample,
        clip_forge_enable=cfg.clip_forge_enable,
        clip_model=cfg.clip_model)

    eval_split = getattr(args, "eval_split", "val")
    # te_dataset has random_subsample as False, therefore not using sample_with_replacement
    te_dataset = ShapeNet15kPointClouds(
        categories=cfg.cates, split=eval_split,
        tr_sample_size=cfg.tr_max_sample_points,
        te_sample_size=cfg.te_max_sample_points,
        scale=cfg.dataset_scale,
        normalize_shape_box=cfg.normalize_shape_box,
        normalize_per_shape=cfg.normalize_per_shape,
        normalize_std_per_axis=cfg.normalize_std_per_axis,
        normalize_global=cfg.normalize_global,
        recenter_per_shape=cfg.recenter_per_shape,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        clip_forge_enable=cfg.clip_forge_enable,
        clip_model=cfg.clip_model)
    return tr_dataset, te_dataset


def get_data_loaders(cfg, args):
    """Build train/test DataLoaders.

    Args:
        cfg: the ``config.data`` sub-part.
        args: namespace with ``distributed`` and optional ``eval_trainnll``.

    Returns:
        dict with keys 'test_loader' and 'train_loader'.
    """
    tr_dataset, te_dataset = get_datasets(cfg, args)
    kwargs = {}
    if args.distributed:
        # NOTE(review): with eval_trainnll the sampler below still shuffles;
        # confirm whether an unshuffled order is expected in that case.
        kwargs['sampler'] = data.distributed.DistributedSampler(
            tr_dataset, shuffle=True)
    else:
        kwargs['shuffle'] = True
    if getattr(args, 'eval_trainnll', False):
        kwargs['shuffle'] = False
    train_loader = data.DataLoader(dataset=tr_dataset,
                                   batch_size=cfg.batch_size,
                                   num_workers=cfg.num_workers,
                                   drop_last=cfg.train_drop_last == 1,
                                   pin_memory=False, **kwargs)
    test_loader = data.DataLoader(dataset=te_dataset,
                                  batch_size=cfg.batch_size_test,
                                  shuffle=False,
                                  num_workers=cfg.num_workers,
                                  pin_memory=False,
                                  drop_last=False)
    logger.info(
        f'[Batch Size] train={cfg.batch_size}, test={cfg.batch_size_test}; drop-last={cfg.train_drop_last}')
    loaders = {
        "test_loader": test_loader,
        'train_loader': train_loader,
    }
    return loaders