LION/datasets/pointflow_datasets.py

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
import os
import open3d as o3d
import time
import torch
import numpy as np
from loguru import logger
from torch.utils.data import Dataset
from torch.utils import data
import random
import tqdm
from datasets.data_path import get_path
from PIL import Image
OVERFIT = 0
# taken from https://github.com/optas/latent_3d_points/blob/
# 8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
synsetid_to_cate = {
    '02691156': 'airplane',
    '02773838': 'bag',
    '02801938': 'basket',
    '02808440': 'bathtub',
    '02818832': 'bed',
    '02828884': 'bench',
    '02876657': 'bottle',
    '02880940': 'bowl',
    '02924116': 'bus',
    '02933112': 'cabinet',
    '02747177': 'can',
    '02942699': 'camera',
    '02954340': 'cap',
    '02958343': 'car',
    '03001627': 'chair',
    '03046257': 'clock',
    '03207941': 'dishwasher',
    '03211117': 'monitor',
    '04379243': 'table',
    '04401088': 'telephone',
    '02946921': 'tin_can',
    '04460130': 'tower',
    '04468005': 'train',
    '03085013': 'keyboard',
    '03261776': 'earphone',
    '03325088': 'faucet',
    '03337140': 'file',
    '03467517': 'guitar',
    '03513137': 'helmet',
    '03593526': 'jar',
    '03624134': 'knife',
    '03636649': 'lamp',
    '03642806': 'laptop',
    '03691459': 'speaker',
    '03710193': 'mailbox',
    '03759954': 'microphone',
    '03761084': 'microwave',
    '03790512': 'motorcycle',
    '03797390': 'mug',
    '03928116': 'piano',
    '03938244': 'pillow',
    '03948459': 'pistol',
    '03991062': 'pot',
    '04004475': 'printer',
    '04074963': 'remote_control',
    '04090263': 'rifle',
    '04099429': 'rocket',
    '04225987': 'skateboard',
    '04256520': 'sofa',
    '04330267': 'stove',
    '04530566': 'vessel',
    '04554684': 'washer',
    '02992529': 'cellphone',
    '02843684': 'birdhouse',
    '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}


class ShapeNet15kPointClouds(Dataset):
    def __init__(self,
                 categories=['airplane'],
                 tr_sample_size=10000,
                 te_sample_size=10000,
                 split='train',
                 scale=1.,
                 normalize_per_shape=False,
                 normalize_shape_box=False,
                 random_subsample=False,
                 sample_with_replacement=1,
                 normalize_std_per_axis=False,
                 normalize_global=False,
                 recenter_per_shape=False,
                 all_points_mean=None,
                 all_points_std=None,
                 input_dim=3,
                 clip_forge_enable=0, clip_model=None
                 ):
        self.clip_forge_enable = clip_forge_enable
        if self.clip_forge_enable:
            import clip
            _, self.clip_preprocess = clip.load(clip_model)
            self.img_path = []
            img_path = get_path('clip_forge_image')
        self.normalize_shape_box = normalize_shape_box
        root_dir = get_path('pointflow')
        self.root_dir = root_dir
        logger.info('[DATA] cat: {}, split: {}, full path: {}; norm global={}, norm-box={}',
                    categories, split, self.root_dir, normalize_global, normalize_shape_box)
        self.split = split
        assert self.split in ['train', 'test', 'val']
        self.tr_sample_size = tr_sample_size
        self.te_sample_size = te_sample_size
        if isinstance(categories, str):
            categories = [categories]
        self.cates = categories
        if 'all' in categories:
            self.synset_ids = list(cate_to_synsetid.values())
        else:
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
        subdirs = self.synset_ids
        # assert 'v2' in root_dir, "Only supporting v2 right now."
        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]
        self.root_dir = root_dir
        self.split = split
        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size
        self.subdirs = subdirs
        self.scale = scale
        self.random_subsample = random_subsample
        self.sample_with_replacement = sample_with_replacement
        self.input_dim = input_dim
        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        tic = time.time()
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is synset id
            sub_path = os.path.join(root_dir, subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s " % sub_path)
                raise ValueError('check the data path')
            all_mids = []
            assert os.path.exists(sub_path), f'path missing: {sub_path}'
            for x in os.listdir(sub_path):
                if not x.endswith('.npy'):
                    continue
                all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
            logger.info('[DATA] number of files [{}] under: {}',
                        len(os.listdir(sub_path)), sub_path)
            # NOTE: [mid] contains the split: i.e. "train/<mid>"
            # or "val/<mid>" or "test/<mid>"
            all_mids = sorted(all_mids)
            for mid in all_mids:
                # obj_fname = os.path.join(sub_path, x)
                if self.clip_forge_enable:
                    synset_id = subd
                    render_img_path = os.path.join(
                        img_path, synset_id, mid.split('/')[-1], 'img_choy2016')
                    # render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
                    # if not (os.path.exists(render_img_path)): continue
                    self.img_path.append(render_img_path)
                    assert os.path.exists(render_img_path), \
                        f'render img path not found: {render_img_path}'
                obj_fname = os.path.join(root_dir, subd, mid + ".npy")
                point_cloud = np.load(obj_fname)  # (15k, 3)
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))
        logger.info('[DATA] Load data time: {:.1f}s | dir: {} | '
                    'sample_with_replacement: {}; num points: {}',
                    time.time() - tic, self.subdirs,
                    self.sample_with_replacement, len(self.all_points))
        # Shuffle the index deterministically (based on the number of examples)
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
        if self.clip_forge_enable:
            self.img_path = [self.img_path[i] for i in self.shuffle_idx]
        # Normalization
        self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
        self.normalize_per_shape = normalize_per_shape
        self.normalize_std_per_axis = normalize_std_per_axis
        self.recenter_per_shape = recenter_per_shape
        if self.normalize_shape_box:  # per-shape bounding-box normalization
            B, N = self.all_points.shape[:2]
            self.all_points_mean = (  # B,1,3
                (np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) +
                (np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)) / 2
            self.all_points_std = np.amax(  # B,1,1
                ((np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) -
                 (np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)),
                axis=-1).reshape(B, 1, 1) / 2
        elif self.normalize_per_shape:  # per-shape mean/std normalization
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.mean(axis=1).reshape(
                B, 1, input_dim)
            logger.info('all_points shape: {}. mean over axis=1',
                        self.all_points.shape)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(
                    B, N, -1).std(axis=1).reshape(B, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(
                    B, -1).std(axis=1).reshape(B, 1, 1)
        elif all_points_mean is not None and all_points_std is not None and not self.recenter_per_shape:
            # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        elif self.recenter_per_shape:  # per-shape center
            # TODO: bounding box scale at the large dim and center
            B, N = self.all_points.shape[:2]
            self.all_points_mean = (
                (np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) +
                (np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)) / 2
            self.all_points_std = np.amax(
                ((np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) -
                 (np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)),
                axis=-1).reshape(B, 1, 1) / 2
        elif normalize_global:  # normalize across the whole dataset
            self.all_points_mean = self.all_points.reshape(
                -1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(
                    -1, input_dim).std(axis=0).reshape(1, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(-1).std(
                    axis=0).reshape(1, 1, 1)
            logger.info('[DATA] normalize_global: mean={}, std={}',
                        self.all_points_mean.reshape(-1),
                        self.all_points_std.reshape(-1))
        else:
            raise NotImplementedError('No Normalization')
        self.all_points = (self.all_points - self.all_points_mean) / \
            self.all_points_std
        logger.info('[DATA] shape={}, all_points_mean:={}, std={}, max={:.3f}, min={:.3f}; num-pts={}',
                    self.all_points.shape,
                    self.all_points_mean.shape, self.all_points_std.shape,
                    self.all_points.max(), self.all_points.min(), tr_sample_size)
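        # Shape bookkeeping for the step above (a worked example, assuming
        # normalize_global with the default normalize_std_per_axis=False):
        # all_points is (N, 15000, 3), all_points_mean is (1, 1, 3) and
        # all_points_std is (1, 1, 1), so broadcasting applies one set of
        # statistics to every point of every shape. With the per-shape
        # options the stats are (B, 1, 3) / (B, 1, 1) instead and broadcast
        # along the point axis only.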
        if OVERFIT:
            self.all_points = self.all_points[:40]
        # TODO: why do we need this??
        # subsample the 15k points per shape down to the first 10k for training
        self.train_points = self.all_points[:, :min(
            10000, self.all_points.shape[1])]
        self.tr_sample_size = min(10000, tr_sample_size)
        self.te_sample_size = min(5000, te_sample_size)
        assert self.scale == 1, "Scale (!= 1) is deprecated"
        # Default display axis order
        self.display_axis_order = [0, 1, 2]

    def get_pc_stats(self, idx):
        if self.recenter_per_shape or self.normalize_per_shape or \
                self.normalize_shape_box:
            # per-shape statistics
            m = self.all_points_mean[idx].reshape(1, self.input_dim)
            s = self.all_points_std[idx].reshape(1, -1)
            return m, s
        # global (dataset-level) statistics
        return self.all_points_mean.reshape(1, -1), \
            self.all_points_std.reshape(1, -1)
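
    # A sketch of undoing the normalization for one item (names as defined in
    # this class; hypothetical usage, not called anywhere in this file):
    #
    #   m, s = dset.get_pc_stats(idx)            # (1, 3) mean, (1, 1|3) std
    #   pts = item['tr_points'].numpy() * s + m  # back to the original scale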

    def renormalize(self, mean, std):
        # undo the current normalization, then re-apply with the given stats
        self.all_points = self.all_points * self.all_points_std + \
            self.all_points_mean
        self.all_points_mean = mean
        self.all_points_std = std
        self.all_points = (self.all_points - self.all_points_mean) / \
            self.all_points_std
        self.train_points = self.all_points[:, :min(
            10000, self.all_points.shape[1])]
        # self.test_points = self.all_points[:, 10000:]

    def __len__(self):
        return len(self.train_points)

    def __getitem__(self, idx):
        output = {}
        tr_out = self.train_points[idx]
        if self.random_subsample and self.sample_with_replacement:
            tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
        elif self.random_subsample and not self.sample_with_replacement:
            tr_idxs = np.random.permutation(
                np.arange(tr_out.shape[0]))[:self.tr_sample_size]
        else:
            tr_idxs = np.arange(self.tr_sample_size)
        tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
        m, s = self.get_pc_stats(idx)
        cate_idx = self.cate_idx_lst[idx]
        sid, mid = self.all_cate_mids[idx]
        input_pts = tr_out
        output.update(
            {
                'idx': idx,
                'select_idx': tr_idxs,
                'tr_points': tr_out,
                'input_pts': input_pts,
                'mean': m,
                'std': s,
                'cate_idx': cate_idx,
                'sid': sid,
                'mid': mid,
                'display_axis_order': self.display_axis_order
            })
        # read images for the CLIP image encoder (clip-forge style conditioning)
        if self.clip_forge_enable:
            img_path = self.img_path[idx]
            img_list = os.listdir(img_path)
            img_list = [os.path.join(img_path, p)
                        for p in img_list if 'jpg' in p or 'png' in p]
            assert len(img_list) > 0, \
                f'get empty list at {img_path}: {os.listdir(img_path)}'
            # randomly pick 5 of the rendered views (with replacement)
            img_idx = np.random.choice(len(img_list), 5)
            img_list = [img_list[o] for o in img_idx]
            img_list = [Image.open(img).convert('RGB') for img in img_list]
            img_list = [self.clip_preprocess(img) for img in img_list]
            all_img = torch.stack(img_list, dim=0)  # 5,3,H,W
            output['tr_img'] = all_img
        return output
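
    # A minimal usage sketch (assuming the PointFlow-style .npy files are
    # available under get_path('pointflow'); keys follow __getitem__ above):
    #
    #   dset = ShapeNet15kPointClouds(categories=['airplane'], split='train',
    #                                 normalize_global=True,
    #                                 random_subsample=True)
    #   item = dset[0]
    #   item['tr_points'].shape    # torch.Size([tr_sample_size, 3])
    #   item['mean'], item['std']  # stats to map points back to world scale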


def init_np_seed(worker_id):
    # re-seed numpy in each DataLoader worker so workers draw different samples
    seed = torch.initial_seed()
    np.random.seed(seed % 4294967296)
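

# Note: the loaders built below do not register init_np_seed; a sketch of how
# it could be passed (worker_init_fn is a standard DataLoader argument):
#
#   loader = data.DataLoader(dataset=tr_dataset, batch_size=16,
#                            num_workers=4, worker_init_fn=init_np_seed)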


def get_datasets(cfg, args):
    """
    cfg: config.data sub part
    """
    if OVERFIT:
        random_subsample = 0
    else:
        random_subsample = cfg.random_subsample
    logger.info(f'get_datasets: tr_sample_size={cfg.tr_max_sample_points}, '
                f'te_sample_size={cfg.te_max_sample_points}; '
                f'random_subsample={random_subsample} '
                f'normalize_global={cfg.normalize_global} '
                f'normalize_std_per_axis={cfg.normalize_std_per_axis} '
                f'normalize_per_shape={cfg.normalize_per_shape} '
                f'recenter_per_shape={cfg.recenter_per_shape}')
    kwargs = {}
    tr_dataset = ShapeNet15kPointClouds(
        categories=cfg.cates,
        split='train',
        tr_sample_size=cfg.tr_max_sample_points,
        te_sample_size=cfg.te_max_sample_points,
        sample_with_replacement=cfg.sample_with_replacement,
        scale=cfg.dataset_scale,  # root_dir=cfg.data_dir,
        normalize_shape_box=cfg.normalize_shape_box,
        normalize_per_shape=cfg.normalize_per_shape,
        normalize_std_per_axis=cfg.normalize_std_per_axis,
        normalize_global=cfg.normalize_global,
        recenter_per_shape=cfg.recenter_per_shape,
        random_subsample=random_subsample,
        clip_forge_enable=cfg.clip_forge_enable,
        clip_model=cfg.clip_model,
        **kwargs)
    eval_split = getattr(args, "eval_split", "val")
    # te_dataset leaves random_subsample at its default (False), so
    # sample_with_replacement is not used for evaluation
    te_dataset = ShapeNet15kPointClouds(
        categories=cfg.cates,
        split=eval_split,
        tr_sample_size=cfg.tr_max_sample_points,
        te_sample_size=cfg.te_max_sample_points,
        scale=cfg.dataset_scale,  # root_dir=cfg.data_dir,
        normalize_shape_box=cfg.normalize_shape_box,
        normalize_per_shape=cfg.normalize_per_shape,
        normalize_std_per_axis=cfg.normalize_std_per_axis,
        normalize_global=cfg.normalize_global,
        recenter_per_shape=cfg.recenter_per_shape,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        clip_forge_enable=cfg.clip_forge_enable,
        clip_model=cfg.clip_model,
    )
    return tr_dataset, te_dataset


def get_data_loaders(cfg, args):
    tr_dataset, te_dataset = get_datasets(cfg, args)
    kwargs = {}
    if args.distributed:
        kwargs['sampler'] = data.distributed.DistributedSampler(
            tr_dataset, shuffle=True)
    else:
        kwargs['shuffle'] = True
    if args.eval_trainnll:
        kwargs['shuffle'] = False
    train_loader = data.DataLoader(dataset=tr_dataset,
                                   batch_size=cfg.batch_size,
                                   num_workers=cfg.num_workers,
                                   drop_last=cfg.train_drop_last == 1,
                                   pin_memory=False, **kwargs)
    test_loader = data.DataLoader(dataset=te_dataset,
                                  batch_size=cfg.batch_size_test,
                                  shuffle=False,
                                  num_workers=cfg.num_workers,
                                  pin_memory=False,
                                  drop_last=False)
    logger.info(
        f'[Batch Size] train={cfg.batch_size}, test={cfg.batch_size_test}; '
        f'drop-last={cfg.train_drop_last}')
    loaders = {
        "test_loader": test_loader,
        'train_loader': train_loader,
    }
    return loaders
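

# End-to-end sketch (hypothetical cfg/args objects; the real ones come from
# the project's config system, with the fields read above):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       cates=['airplane'], tr_max_sample_points=2048,
#       te_max_sample_points=2048, sample_with_replacement=1,
#       dataset_scale=1, normalize_shape_box=False, normalize_per_shape=False,
#       normalize_std_per_axis=False, normalize_global=True,
#       recenter_per_shape=False, random_subsample=1, clip_forge_enable=0,
#       clip_model=None, batch_size=16, batch_size_test=16, num_workers=4,
#       train_drop_last=1)
#   args = SimpleNamespace(distributed=False, eval_trainnll=False)
#   loaders = get_data_loaders(cfg, args)
#   batch = next(iter(loaders['train_loader']))
#   batch['tr_points'].shape  # torch.Size([16, 2048, 3])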