commit 2f6aa752a6d37b85a380766000d869bbc3bc96d4
Author: Linqi (Alex) Zhou
Date:   Tue Oct 19 13:54:46 2021 -0700

    PVD

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1933b04
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+.idea
+data
+output
+__pycache__
+*.png
+*.net
+*.npy
+*.npz
+eval_model
+eval_data
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..45bf27f
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "metrics/ChamferDistancePytorch"]
+	path = metrics/ChamferDistancePytorch
+	url = https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a7bbcfe
--- /dev/null
+++ b/README.md
@@ -0,0 +1,59 @@
+# Shape Generation and Completion Through Point-Voxel Diffusion
+
+[Project]() | [Paper]()
+
+Implementation of *Shape Generation and Completion Through Point-Voxel Diffusion*.
+
+## Pretrained Models
+
+Pretrained models can be accessed [here](https://www.dropbox.com/s/a3xydf594fzaokl/cifar10_pretrained.rar?dl=0).
+
+## Requirements
+
+Make sure the following dependencies are installed:
+
+```
+python==3.6
+pytorch==1.4.0
+torchvision==0.5.0
+cudatoolkit==10.1
+matplotlib==2.2.5
+tqdm==4.32.1
+open3d==0.9.0
+```
+
+The code was tested on Ubuntu with a Titan RTX GPU.
+
+## Training on CIFAR-10
+
+```bash
+$ python train_cifar.py
+```
+
+Please refer to the Python file for optimal training parameters.
+
+## Results
+
+Some generative results are as follows.
+*(figures: generated point-cloud samples)*
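+
+The dataset wrapper added in `datasets/shapenet_data_pc.py` can also be used on its own for quick inspection. A minimal sketch follows; the data path, category, and sample sizes are illustrative placeholders, not tuned settings:
+
+```python
+from torch.utils.data import DataLoader
+from datasets.shapenet_data_pc import ShapeNet15kPointClouds
+
+# Assumes the ShapeNetCore.v2.PC15k layout: one 15k-point .npy file per shape,
+# grouped as <root>/<synset_id>/<split>/<model_id>.npy
+ds = ShapeNet15kPointClouds(root_dir='data/ShapeNetCore.v2.PC15k',
+                            categories=['airplane'], split='train',
+                            tr_sample_size=2048, te_sample_size=2048,
+                            random_subsample=True)
+loader = DataLoader(ds, batch_size=16, shuffle=True, num_workers=4)
+batch = next(iter(loader))
+print(batch['train_points'].shape)  # torch.Size([16, 2048, 3]), normalized by dataset stats
+```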
+
+## Reference
+
+```
+@inproceedings{han2020joint,
+  title={Joint Training of Variational Auto-Encoder and Latent Energy-Based Model},
+  author={Han, Tian and Nijkamp, Erik and Zhou, Linqi and Pang, Bo and Zhu, Song-Chun and Wu, Ying Nian},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={7978--7987},
+  year={2020}
+}
+```
+
+## Acknowledgement
+
+For any questions related to the code and experiment settings, please contact Linqi (Alex) Zhou (alexzhou907@gmail.com). For questions related to the model and algorithm in the paper, please contact Tian Han (hantian@ucla.edu). Thanks to [@Tian Han](https://github.com/hthth0801?tab=repositories) and [@Erik Nijkamp](https://github.com/enijkamp) for their collaboration and guidance.
\ No newline at end of file
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/datasets/partnet.py b/datasets/partnet.py
new file mode 100644
index 0000000..72417a0
--- /dev/null
+++ b/datasets/partnet.py
@@ -0,0 +1,213 @@
+from torch.utils.data import Dataset, DataLoader
+import torch
+import numpy as np
+import os
+import json
+import random
+import trimesh
+import csv
+from plyfile import PlyData, PlyElement
+from glob import glob
+
+def project_pc_to_image(points, resolution=64):
+    """project point clouds into a 2D image
+    :param points: (n, 3), in range (-1, 1)
+    :return: binary image
+    """
+    img = []
+    for i in range(3):
+        canvas = np.zeros((resolution, resolution))
+        axis = [0, 1, 2]
+        axis.remove(i)
+        proj_points = (points[:, axis] + 1) / 2 * resolution
+        proj_points = proj_points.astype(int)
+        canvas[proj_points[:, 0], proj_points[:, 1]] = 1
+        img.append(canvas)
+    img = np.concatenate(img, axis=1)
+    return img
+
+
+def write_ply(points, filename, text=False):
+    """ input: Nx3, write points to filename as PLY format.
""" + points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] + vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) + el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) + with open(filename, mode='wb') as f: + PlyData([el], text=text).write(f) + + +def rotate_point_cloud(points, transformation_mat): + + new_points = np.dot(transformation_mat, points.T).T + + return new_points + + +def rotate_point_cloud_by_axis_angle(points, axis, angle_deg): + """ align 3depn shapes to shapenet coordinates""" + # angle = math.radians(angle_deg) + # rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle) + # rot_m = rot_m.to_matrix() + rot_m = np.array([[ 2.22044605e-16, 0.00000000e+00, 1.00000000e+00], + [ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00], + [-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]]) + + new_points = rotate_point_cloud(points, rot_m) + + return new_points + + +def downsample_point_cloud(points, n_pts): + """downsample points by random choice + :param points: (n, 3) + :param n_pts: int + :return: + """ + p_idx = random.choices(list(range(points.shape[0])), k=n_pts) + return points[p_idx] + + +def upsample_point_cloud(points, n_pts): + """upsample points by random choice + :param points: (n, 3) + :param n_pts: int, > n + :return: + """ + p_idx = random.choices(list(range(points.shape[0])), k=n_pts - points.shape[0]) + dup_points = points[p_idx] + points = np.concatenate([points, dup_points], axis=0) + return points + + +def sample_point_cloud_by_n(points, n_pts): + """resample point cloud to given number of points""" + if n_pts > points.shape[0]: + return upsample_point_cloud(points, n_pts) + elif n_pts < points.shape[0]: + return downsample_point_cloud(points, n_pts) + else: + return points + + + +def collect_data_id(split_dir, classname, phase): + filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase)) + if not os.path.exists(filename): + raise ValueError("Invalid filepath: {}".format(filename)) + + all_ids = [] + with open(filename, 'r') as fp: + info = json.load(fp) + for item in info: + all_ids.append(item["anno_id"]) + + return all_ids + + + +class GANdatasetPartNet(Dataset): + def __init__(self, phase, data_root, category, n_pts): + super(GANdatasetPartNet, self).__init__() + if phase == "validation": + phase = "val" + + self.phase = phase + self.aug = phase == "train" + + self.data_root = data_root + + shape_names = collect_data_id(os.path.join(self.data_root, 'partnet_labels/partnet_train_val_test_split'), category, phase) + self.shape_names = [] + for name in shape_names: + path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', name) + if os.path.exists(path): + self.shape_names.append(name) + + self.n_pts = n_pts + self.raw_n_pts = self.n_pts // 2 + + self.rng = random.Random(1234) + + @staticmethod + def load_point_cloud(path): + pc = trimesh.load(path) + pc = pc.vertices / 2.0 # scale to unit sphere + return pc + + @staticmethod + def read_point_cloud_part_label(path): + with open(path, 'r') as fp: + labels = fp.readlines() + labels = np.array([int(x) for x in labels]) + return labels + + def random_rm_parts(self, raw_pc, part_labels): + part_ids = sorted(np.unique(part_labels).tolist()) + if self.phase == "train": + random.shuffle(part_ids) + n_part_keep = random.randint(1, max(1, len(part_ids) - 1)) + else: + self.rng.shuffle(part_ids) + n_part_keep = self.rng.randint(1, max(1, len(part_ids) - 1)) + part_ids_keep = part_ids[:n_part_keep] + point_idx = [] + for i in 
part_ids_keep:
+            point_idx.extend(np.where(part_labels == i)[0].tolist())
+        raw_pc = raw_pc[point_idx]
+        return raw_pc, n_part_keep
+
+    def __getitem__(self, index):
+        raw_shape_name = self.shape_names[index]
+        raw_ply_path = os.path.join(self.data_root, 'partnet_data', raw_shape_name, 'point_sample/ply-10000.ply')
+        raw_pc = self.load_point_cloud(raw_ply_path)
+
+        raw_label_path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', raw_shape_name, 'label-merge-level1-10000.txt')
+        part_labels = self.read_point_cloud_part_label(raw_label_path)
+        raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
+        raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
+        raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)
+
+        real_shape_name = self.shape_names[index]
+        real_ply_path = os.path.join(self.data_root, 'partnet_data', real_shape_name, 'point_sample/ply-10000.ply')
+        real_pc = self.load_point_cloud(real_ply_path)
+        real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
+        real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)
+
+        return {"raw": raw_pc, "real": real_pc, "raw_id": raw_shape_name, "real_id": real_shape_name,
+                'n_part_keep': n_part_keep, 'idx': index}
+
+    def __len__(self):
+        return len(self.shape_names)
+
+
+if __name__ == '__main__':
+    data_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetPointCloud'
+    data_raw_root = '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc'
+    pc_dataroot = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k'
+
+    sn_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2'
+    classes = 'car'
+    npoints = 2048
+    # from datasets.shapenet_data_pc import ShapeNet15kPointClouds
+    # pc_ds = ShapeNet15kPointClouds(root_dir=pc_dataroot,
+    #                                categories=[classes], split='train',
+    #                                tr_sample_size=npoints,
+    #                                te_sample_size=npoints,
+    #                                scale=1.,
+    #                                normalize_per_shape=False,
+    #                                normalize_std_per_axis=False,
+    #                                random_subsample=True)
+
+    # GANdatasetPartNet signature: (phase, data_root, category, n_pts)
+    train_ds = GANdatasetPartNet('test', pc_dataroot, classes, npoints)
+
+    d1 = train_ds[0]
+    real = d1['real']
+    raw = d1['raw']
+    # __getitem__ provides no per-shape mean/std, so write the points directly
+    x = torch.cat([raw, real], dim=-1).transpose(0, 1)
+
+    write_ply(x.numpy(), 'x.ply')
diff --git a/datasets/shapenet_data_pc.py b/datasets/shapenet_data_pc.py
new file mode 100644
index 0000000..78f4709
--- /dev/null
+++ b/datasets/shapenet_data_pc.py
@@ -0,0 +1,268 @@
+import os
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+from torch.utils import data
+import random
+import open3d as o3d
+import torch.nn.functional as F
+
+# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
+synsetid_to_cate = {
+    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
+    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
+    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
+    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
+    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
+    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
+    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
+    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
+    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
+    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
+    '03624134': 'knife', '03636649': 'lamp', '03642806':
'laptop', + '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', + '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', + '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', + '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', + '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', + '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', + '04554684': 'washer', '02992529': 'cellphone', + '02843684': 'birdhouse', '02871439': 'bookshelf', + # '02858304': 'boat', no boat in our dataset, merged into vessels + # '02834778': 'bicycle', not in our taxonomy +} +cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()} + + +class Uniform15KPC(Dataset): + def __init__(self, root_dir, subdirs, tr_sample_size=10000, + te_sample_size=10000, split='train', scale=1., + normalize_per_shape=False, box_per_shape=False, + random_subsample=False, + normalize_std_per_axis=False, + all_points_mean=None, all_points_std=None, + input_dim=3, use_mask=False): + self.root_dir = root_dir + self.split = split + self.in_tr_sample_size = tr_sample_size + self.in_te_sample_size = te_sample_size + self.subdirs = subdirs + self.scale = scale + self.random_subsample = random_subsample + self.input_dim = input_dim + self.use_mask = use_mask + self.box_per_shape = box_per_shape + if use_mask: + self.mask_transform = PointCloudMasks(radius=5, elev=5, azim=90) + + self.all_cate_mids = [] + self.cate_idx_lst = [] + self.all_points = [] + for cate_idx, subd in enumerate(self.subdirs): + # NOTE: [subd] here is synset id + sub_path = os.path.join(root_dir, subd, self.split) + if not os.path.isdir(sub_path): + print("Directory missing : %s" % sub_path) + continue + + all_mids = [] + for x in os.listdir(sub_path): + if not x.endswith('.npy'): + continue + all_mids.append(os.path.join(self.split, x[:-len('.npy')])) + + # NOTE: [mid] contains the split: i.e. 
"train/" or "val/" or "test/" + for mid in all_mids: + # obj_fname = os.path.join(sub_path, x) + obj_fname = os.path.join(root_dir, subd, mid + ".npy") + try: + point_cloud = np.load(obj_fname) # (15k, 3) + + except: + continue + + assert point_cloud.shape[0] == 15000 + self.all_points.append(point_cloud[np.newaxis, ...]) + self.cate_idx_lst.append(cate_idx) + self.all_cate_mids.append((subd, mid)) + + # Shuffle the index deterministically (based on the number of examples) + self.shuffle_idx = list(range(len(self.all_points))) + random.Random(38383).shuffle(self.shuffle_idx) + self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx] + self.all_points = [self.all_points[i] for i in self.shuffle_idx] + self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx] + + # Normalization + self.all_points = np.concatenate(self.all_points) # (N, 15000, 3) + self.normalize_per_shape = normalize_per_shape + self.normalize_std_per_axis = normalize_std_per_axis + if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats + self.all_points_mean = all_points_mean + self.all_points_std = all_points_std + elif self.normalize_per_shape: # per shape normalization + B, N = self.all_points.shape[:2] + self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim) + if normalize_std_per_axis: + self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim) + else: + self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1) + elif self.box_per_shape: + B, N = self.all_points.shape[:2] + self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim) + + self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim) + + else: # normalize across the dataset + self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim) + if normalize_std_per_axis: + self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim) + else: + self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1) + + self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std + if self.box_per_shape: + self.all_points = self.all_points - 0.5 + self.train_points = self.all_points[:, :10000] + self.test_points = self.all_points[:, 10000:] + + self.tr_sample_size = min(10000, tr_sample_size) + self.te_sample_size = min(5000, te_sample_size) + print("Total number of data:%d" % len(self.train_points)) + print("Min number of points: (train)%d (test)%d" + % (self.tr_sample_size, self.te_sample_size)) + assert self.scale == 1, "Scale (!= 1) is deprecated" + + def get_pc_stats(self, idx): + if self.normalize_per_shape or self.box_per_shape: + m = self.all_points_mean[idx].reshape(1, self.input_dim) + s = self.all_points_std[idx].reshape(1, -1) + return m, s + + + return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1) + + def renormalize(self, mean, std): + self.all_points = self.all_points * self.all_points_std + self.all_points_mean + self.all_points_mean = mean + self.all_points_std = std + self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std + self.train_points = self.all_points[:, :10000] + self.test_points = self.all_points[:, 10000:] + + def __len__(self): + return len(self.train_points) + + def __getitem__(self, idx): + tr_out = self.train_points[idx] + if self.random_subsample: + tr_idxs 
= np.random.choice(tr_out.shape[0], self.tr_sample_size) + else: + tr_idxs = np.arange(self.tr_sample_size) + tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float() + + te_out = self.test_points[idx] + if self.random_subsample: + te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size) + else: + te_idxs = np.arange(self.te_sample_size) + te_out = torch.from_numpy(te_out[te_idxs, :]).float() + + m, s = self.get_pc_stats(idx) + cate_idx = self.cate_idx_lst[idx] + sid, mid = self.all_cate_mids[idx] + + out = { + 'idx': idx, + 'train_points': tr_out, + 'test_points': te_out, + 'mean': m, 'std': s, 'cate_idx': cate_idx, + 'sid': sid, 'mid': mid + } + + if self.use_mask: + # masked = torch.from_numpy(self.mask_transform(self.all_points[idx])) + # ss = min(masked.shape[0], self.in_tr_sample_size//2) + # masked = masked[:ss] + # + # tr_mask = torch.ones_like(masked) + # masked = torch.cat([masked, torch.zeros(self.in_tr_sample_size - ss, 3)],dim=0)#F.pad(masked, (self.in_tr_sample_size-masked.shape[0], 0), "constant", 0) + # + # tr_mask = torch.cat([tr_mask, torch.zeros(self.in_tr_sample_size- ss, 3)],dim=0)#F.pad(tr_mask, (self.in_tr_sample_size-tr_mask.shape[0], 0), "constant", 0) + # out['train_points_masked'] = masked + # out['train_masks'] = tr_mask + tr_mask = self.mask_transform(tr_out) + out['train_masks'] = tr_mask + + return out + + +class ShapeNet15kPointClouds(Uniform15KPC): + def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k", + categories=['airplane'], tr_sample_size=10000, te_sample_size=2048, + split='train', scale=1., normalize_per_shape=False, + normalize_std_per_axis=False, box_per_shape=False, + random_subsample=False, + all_points_mean=None, all_points_std=None, + use_mask=False): + self.root_dir = root_dir + self.split = split + assert self.split in ['train', 'test', 'val'] + self.tr_sample_size = tr_sample_size + self.te_sample_size = te_sample_size + self.cates = categories + if 'all' in categories: + self.synset_ids = list(cate_to_synsetid.values()) + else: + self.synset_ids = [cate_to_synsetid[c] for c in self.cates] + + # assert 'v2' in root_dir, "Only supporting v2 right now." 
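+        # ShapeNet v2 shapes are y-up: axis 1 is the gravity axis, and the
+        # display order below swaps y/z for plotting.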
+ self.gravity_axis = 1 + self.display_axis_order = [0, 2, 1] + + super(ShapeNet15kPointClouds, self).__init__( + root_dir, self.synset_ids, + tr_sample_size=tr_sample_size, + te_sample_size=te_sample_size, + split=split, scale=scale, + normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape, + normalize_std_per_axis=normalize_std_per_axis, + random_subsample=random_subsample, + all_points_mean=all_points_mean, all_points_std=all_points_std, + input_dim=3, use_mask=use_mask) + + + +class PointCloudMasks(object): + ''' + render a view then save mask + ''' + def __init__(self, radius : float=10, elev: float =45, azim:float=315, ): + + self.radius = radius + self.elev = elev + self.azim = azim + + + def __call__(self, points): + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + + camera = [self.radius * np.sin(90-self.elev) * np.cos(self.azim), + self.radius * np.cos(90 - self.elev), + self.radius * np.sin(90 - self.elev) * np.sin(self.azim), + ] + # camera = [0,self.radius,0] + _, pt_map = pcd.hidden_point_removal(camera, self.radius) + + mask = torch.zeros_like(points) + mask[pt_map] = 1 + + return mask #points[pt_map] + + +#################################################################################### + + diff --git a/datasets/shapenet_data_sv.py b/datasets/shapenet_data_sv.py new file mode 100644 index 0000000..54edb92 --- /dev/null +++ b/datasets/shapenet_data_sv.py @@ -0,0 +1,257 @@ +import warnings +from torch.utils.data import Dataset +from tqdm import tqdm +from pathlib import Path +import open3d as o3d +import os +import numpy as np + +import hashlib +import torch +import matplotlib.pyplot as plt + +synset_to_label = { + '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', + '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', + '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', + '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', + '02954340': 'cap', '02958343': 'car', '03001627': 'chair', + '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', + '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', + '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', + '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', + '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', + '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', + '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', + '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', + '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', + '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', + '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', + '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', + '04554684': 'washer', '02992529': 'cellphone', + '02843684': 'birdhouse', '02871439': 'bookshelf', + # '02858304': 'boat', no boat in our dataset, merged into vessels + # '02834778': 'bicycle', not in our taxonomy +} + +# Label to Synset mapping (for ShapeNet core classes) +label_to_synset = {v: k for k, v in synset_to_label.items()} + +def _convert_categories(categories): + assert categories is not None, 'List of categories cannot be empty!' + if not (c in synset_to_label.keys() + label_to_synset.keys() + for c in categories): + warnings.warn('Some or all of the categories requested are not part of \ + ShapeNetCore. 
Data loading may fail if these categories are not avaliable.') + synsets = [label_to_synset[c] if c in label_to_synset.keys() + else c for c in categories] + return synsets + + +class ShapeNet_Multiview_Points(Dataset): + def __init__(self, root_pc:str, root_views: str, cache: str, categories: list = ['chair'], split: str= 'val', + npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False): + self.root = Path(root_views) + self.split = split + self.get_image = get_image + params = { + 'cat': categories, + 'npoints': npoints, + 'sv_samples': sv_samples, + } + params = tuple(sorted(pair for pair in params.items())) + self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format('_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest()) + + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.paths = [] + self.synset_idxs = [] + self.synsets = _convert_categories(categories) + self.labels = [synset_to_label[s] for s in self.synsets] + self.npoints = npoints + self.sv_samples = sv_samples + + self.all_points = [] + self.all_points_sv = [] + + # loops through desired classes + for i in range(len(self.synsets)): + + syn = self.synsets[i] + class_target = self.root / syn + if not class_target.exists(): + raise ValueError('Class {0} ({1}) was not found at location {2}.'.format( + syn, self.labels[i], str(class_target))) + + + sub_path_pc = os.path.join(root_pc, syn, split) + if not os.path.isdir(sub_path_pc): + print("Directory missing : %s" % sub_path_pc) + continue + + self.all_mids = [] + self.imgs = [] + for x in os.listdir(sub_path_pc): + if not x.endswith('.npy'): + continue + self.all_mids.append(os.path.join(split, x[:-len('.npy')])) + + for mid in tqdm(self.all_mids): + # obj_fname = os.path.join(sub_path, x) + obj_fname = os.path.join(root_pc, syn, mid + ".npy") + cams_pths = list((self.root/ syn/ mid.split('/')[-1]).glob('*_cam_params.npz')) + if len(cams_pths) < 20: + continue + point_cloud = np.load(obj_fname) + sv_points_group = [] + img_path_group = [] + (self.cache_dir / (mid.split('/')[-1])).mkdir(parents=True, exist_ok=True) + success = True + for i, cp in enumerate(cams_pths): + cp = str(cp) + vp = cp.split('cam_params')[0] + 'depth.png' + depth_minmax_pth = cp.split('_cam_params')[0] + '.npy' + cache_pth = str(self.cache_dir / mid.split('/')[-1] / os.path.basename(depth_minmax_pth) ) + + cam_params = np.load(cp) + extr = cam_params['extr'] + intr = cam_params['intr'] + + self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr) + + try: + sv_point_cloud = self._render(cache_pth, vp, depth_minmax_pth) + + img_path_group.append(vp) + + sv_points_group.append(sv_point_cloud) + except Exception as e: + print(e) + success=False + break + if not success: + continue + self.all_points_sv.append(np.stack(sv_points_group, axis=0)) + self.all_points.append(point_cloud) + self.imgs.append(img_path_group) + + self.all_points = np.stack(self.all_points, axis=0) + + self.all_points_sv = np.stack(self.all_points_sv, axis=0) + if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats + self.all_points_mean = all_points_mean + self.all_points_std = all_points_std + else: # normalize across the dataset + self.all_points_mean = self.all_points.reshape(-1, 3).mean(axis=0).reshape(1, 1, 3) + self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1) + + self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std + self.train_points = self.all_points[:,:10000] + 
self.test_points = self.all_points[:,10000:] + self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std + + def get_pc_stats(self, idx): + + return self.all_points_mean.reshape(1,1, -1), self.all_points_std.reshape(1,1, -1) + + def __len__(self): + """Returns the length of the dataset. """ + return len(self.all_points) + + def __getitem__(self, index): + + + tr_out = self.train_points[index] + tr_idxs = np.random.choice(tr_out.shape[0], self.npoints) + tr_out = tr_out[tr_idxs, :] + + gt_points = self.test_points[index][:self.npoints] + + m, s = self.get_pc_stats(index) + + sv_points = self.all_points_sv[index] + + idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]#np.random.choice(sv_points.shape[0], 500, replace=False) + + data = torch.cat([torch.from_numpy(sv_points[:,idxs]).float(), + torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2])], dim=1) + masks = torch.zeros_like(data) + masks[:,:idxs.shape[0]] = 1 + + res = {'train_points': torch.from_numpy(tr_out).float(), + 'test_points': torch.from_numpy(gt_points).float(), + 'sv_points': data, + 'masks': masks, + 'std': s, 'mean': m, + 'idx': index, + 'name':self.all_mids[index] + } + + if self.split != 'train' and self.get_image: + + img_lst = [] + for n in range(self.all_points_sv.shape[1]): + + img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2,0,1)[:3] + + img_lst.append(img) + + img = torch.stack(img_lst, dim=0) + + res['image'] = img + + return res + + + + def _render(self, cache_path, depth_pth, depth_minmax_pth): + # if not os.path.exists(cache_path.split('.npy')[0] + '_color.png') and os.path.exists(cache_path): + # + # os.remove(cache_path) + + if os.path.exists(cache_path): + data = np.load(cache_path) + else: + + data, depth = self.transform(depth_pth, depth_minmax_pth) + assert data.shape[0] > 600, 'Only {} points found'.format(data.shape[0]) + data = data[np.random.choice(data.shape[0], 600, replace=False)] + np.save(cache_path, data) + + return data + + + + +class DepthToSingleViewPoints(object): + ''' + render a view then save mask + ''' + def __init__(self, cam_ext, cam_int): + + self.cam_ext = cam_ext.reshape(4,4) + self.cam_int = cam_int.reshape(3,3) + + + def __call__(self, depth_pth, depth_minmax_pth): + + depth_minmax = np.load(depth_minmax_pth) + depth_img = plt.imread(depth_pth)[...,0] + mask = np.where(depth_img == 0, -1.0, 1.0) + depth_img = 1 - depth_img + depth_img = (depth_img * (np.max(depth_minmax) - np.min(depth_minmax)) + np.min(depth_minmax)) * mask + + intr = o3d.camera.PinholeCameraIntrinsic(depth_img.shape[0], depth_img.shape[1], + self.cam_int[0, 0], self.cam_int[1, 1], self.cam_int[0,2], + self.cam_int[1,2]) + + depth_im = o3d.geometry.Image(depth_img.astype(np.float32, copy=False)) + + # rgbd_im = o3d.geometry.RGBDImage.create_from_color_and_depth(color_im, depth_im) + pcd = o3d.geometry.PointCloud.create_from_depth_image(depth_im, intr, self.cam_ext, depth_scale=1.) 
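+        # back-project the de-normalized depth map into a 3D point cloud;
+        # cam_ext maps it out of the camera frame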
+        pc = np.asarray(pcd.points)
+
+        return pc, depth_img
+
+    def __repr__(self):
+        return 'DepthToSingleViewPoints'
+
diff --git a/metrics/.gitignore b/metrics/.gitignore
new file mode 100644
index 0000000..98c2a8b
--- /dev/null
+++ b/metrics/.gitignore
@@ -0,0 +1 @@
+StructuralLosses
diff --git a/metrics/ChamferDistancePytorch/.gitignore b/metrics/ChamferDistancePytorch/.gitignore
new file mode 100644
index 0000000..abd32e8
--- /dev/null
+++ b/metrics/ChamferDistancePytorch/.gitignore
@@ -0,0 +1 @@
+*__pycache__*
\ No newline at end of file
diff --git a/metrics/ChamferDistancePytorch/LICENSE b/metrics/ChamferDistancePytorch/LICENSE
new file mode 100644
index 0000000..794e2df
--- /dev/null
+++ b/metrics/ChamferDistancePytorch/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 ThibaultGROUEIX
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/metrics/ChamferDistancePytorch/README.md b/metrics/ChamferDistancePytorch/README.md
new file mode 100644
index 0000000..9b4df4f
--- /dev/null
+++ b/metrics/ChamferDistancePytorch/README.md
@@ -0,0 +1,101 @@
+# Pytorch Chamfer Distance.
+
+Includes a **CUDA** version and a **PYTHON** version with PyTorch standard operations.
+NB: In this repo, dist1 and dist2 are squared point cloud euclidean distances, so you should adapt thresholds accordingly.
+
+- [x] F-Score
+
+### CUDA VERSION
+
+- [x] JIT compilation
+- [x] Supports multi-gpu
+- [x] 2D point clouds.
+- [x] 3D point clouds.
+- [x] 5D point clouds.
+- [x] Contiguous() safe.
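+
+The 2D/3D/5D modules all expose the same call signature; a minimal 2D sketch along the lines of `unit_test.py` (a CUDA device is assumed):
+
+```python
+import torch
+import chamfer2D.dist_chamfer_2D
+
+cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
+points1 = torch.rand(4, 100, 2).cuda()
+points2 = torch.rand(4, 200, 2, requires_grad=True).cuda()
+dist1, dist2, idx1, idx2 = cham2D(points1, points2)  # squared distances + nearest-neighbor indices
+torch.sum(dist1).backward()
+```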
+ + + +### Python Version + +- [x] Supports any dimension + + + +### Usage + +```python +import torch, chamfer3D.dist_chamfer_3D, fscore +chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist() +points1 = torch.rand(32, 1000, 3).cuda() +points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda() +dist1, dist2, idx1, idx2 = chamLoss(points1, points2) +f_score, precision, recall = fscore.fscore(dist1, dist2) +``` + + + +### Add it to your project as a submodule + +```shell +git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch +``` + + + +### Benchmark: [forward + backward] pass +- [x] CUDA 10.1, NVIDIA 435, Pytorch 1.4 +- [x] p1 : 32 x 2000 x dim +- [x] p2 : 32 x 1000 x dim + +| *Timing (sec * 1000)* | 2D | 3D | 5D | +| ---------- | -------- | ------- | ------- | +| **Cuda Compiled** | **1.2** | 1.4 |1.8 | +| **Cuda JIT** | 1.3 | **1.4** |**1.5** | +| **Python** | 37 | 37 | 37 | + + +| *Memory (MB)* | 2D | 3D | 5D | +| ---------- | -------- | ------- | ------- | +| **Cuda Compiled** | 529 | 529 | 549 | +| **Cuda JIT** | **520** | **529** |**549** | +| **Python** | 2495 | 2495 | 2495 | + + + +### What is the chamfer distance ? + +[Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep Learning + + + +### Aknowledgment + +Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu). + +JIT cool trick from [Christian Diller](https://github.com/chrdiller) + +### Troubleshoot + +- `Undefined symbol: Zxxxxxxxxxxxxxxxxx `: + +--> Fix: Make sure to `import torch` before you `import chamfer`. +--> Use pytorch.version >= 1.1.0 + +- [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167) + +```shell +wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +sudo unzip ninja-linux.zip -d /usr/local/bin/ +sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force +``` + + + + + +#### TODO: + +* Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions diff --git a/metrics/ChamferDistancePytorch/chamfer2D/chamfer2D.cu b/metrics/ChamferDistancePytorch/chamfer2D/chamfer2D.cu new file mode 100644 index 0000000..567dd1a --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer2D/chamfer2D.cu @@ -0,0 +1,182 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*2]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != 
cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} + diff --git a/metrics/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp b/metrics/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp new file mode 100644 index 0000000..67574e2 --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/metrics/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py b/metrics/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py new file mode 100644 index 0000000..f92e6f1 --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py @@ -0,0 +1,73 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os +chamfer_found = importlib.find_loader("chamfer_2D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 2D") + + from torch.utils.cpp_extension import load + chamfer_2D = load(name="chamfer_2D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]), + ]) + print("Loaded JIT 2D CUDA chamfer distance") + +else: + import chamfer_2D + print("Loaded compiled 2D CUDA chamfer distance") + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_2DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, _ = xyz1.size() + _, m, _ = xyz2.size() + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = 
idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_2D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_2DDist(nn.Module): + def __init__(self): + super(chamfer_2DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_2DFunction.apply(input1, input2) diff --git a/metrics/ChamferDistancePytorch/chamfer2D/setup.py b/metrics/ChamferDistancePytorch/chamfer2D/setup.py new file mode 100644 index 0000000..6fca431 --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer2D/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_2D', + ext_modules=[ + CUDAExtension('chamfer_2D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']), + ]), + ], + + extra_cuda_cflags=['--compiler-bindir=/usr/bin/gcc-8'], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/metrics/ChamferDistancePytorch/chamfer3D/chamfer3D.cu b/metrics/ChamferDistancePytorch/chamfer3D/chamfer3D.cu new file mode 100644 index 0000000..d5b886d --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer3D/chamfer3D.cu @@ -0,0 +1,196 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + 
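+	// mirrored launch: back-propagate graddist2 with the roles of the two point clouds swapped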
NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} + diff --git a/metrics/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp b/metrics/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp new file mode 100644 index 0000000..67574e2 --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/metrics/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py b/metrics/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py new file mode 100644 index 0000000..f3f7587 --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py @@ -0,0 +1,77 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os +chamfer_found = importlib.find_loader("chamfer_3D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 3D") + + from torch.utils.cpp_extension import load + chamfer_3D = load(name="chamfer_3D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), + ], + + extra_cuda_cflags=['--compiler-bindir=/usr/bin/gcc-8'],) + print("Loaded JIT 3D CUDA chamfer distance") + +else: + import chamfer_3D + print("Loaded compiled 3D CUDA chamfer distance") + + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_3DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, _ = xyz1.size() + _, m, _ = xyz2.size() + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + 
graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_3D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_3DDist(nn.Module): + def __init__(self): + super(chamfer_3DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_3DFunction.apply(input1, input2) + diff --git a/metrics/ChamferDistancePytorch/chamfer3D/setup.py b/metrics/ChamferDistancePytorch/chamfer3D/setup.py new file mode 100644 index 0000000..0200b9f --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer3D/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_3D', + ext_modules=[ + CUDAExtension('chamfer_3D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), + ]), + ], + + extra_cuda_cflags=['--compiler-bindir=/usr/bin/gcc-8'], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/metrics/ChamferDistancePytorch/chamfer5D/chamfer5D.cu b/metrics/ChamferDistancePytorch/chamfer5D/chamfer5D.cu new file mode 100644 index 0000000..650e889 --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer5D/chamfer5D.cu @@ -0,0 +1,223 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=2048; + __shared__ float buf[batch*5]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} diff --git a/metrics/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp b/metrics/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp new file mode 100644 index 0000000..67574e2 --- 
/dev/null +++ b/metrics/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp @@ -0,0 +1,33 @@ +#include +#include + +///TMP +//#include "common.h" +/// NOT TMP + + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); + + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); + + + + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, + at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/metrics/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py b/metrics/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py new file mode 100644 index 0000000..3730a1f --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py @@ -0,0 +1,75 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os + +chamfer_found = importlib.find_loader("chamfer_5D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 5D") + + from torch.utils.cpp_extension import load + chamfer_5D = load(name="chamfer_5D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]), + ]) + print("Loaded JIT 5D CUDA chamfer distance") + +else: + import chamfer_5D + print("Loaded compiled 5D CUDA chamfer distance") + + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_5DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, _ = xyz1.size() + _, m, _ = xyz2.size() + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_5D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_5D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_5DDist(nn.Module): + def __init__(self): + super(chamfer_5DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return 
chamfer_5DFunction.apply(input1, input2) diff --git a/metrics/ChamferDistancePytorch/chamfer5D/setup.py b/metrics/ChamferDistancePytorch/chamfer5D/setup.py new file mode 100644 index 0000000..51f5960 --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer5D/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_5D', + ext_modules=[ + CUDAExtension('chamfer_5D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']), + ]), + ], + + extra_cuda_cflags=['--compiler-bindir=/usr/bin/gcc-8'], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/metrics/ChamferDistancePytorch/chamfer_python.py b/metrics/ChamferDistancePytorch/chamfer_python.py new file mode 100644 index 0000000..ce0aeaa --- /dev/null +++ b/metrics/ChamferDistancePytorch/chamfer_python.py @@ -0,0 +1,40 @@ +import torch + + +def pairwise_dist(x, y): + xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) + rx = xx.diag().unsqueeze(0).expand_as(xx) + ry = yy.diag().unsqueeze(0).expand_as(yy) + P = rx.t() + ry - 2 * zz + return P + + +def NN_loss(x, y, dim=0): + dist = pairwise_dist(x, y) + values, indices = dist.min(dim=dim) + return values.mean() + + +def distChamfer(a, b): + """ + :param a: Pointclouds Batch x nul_points x dim + :param b: Pointclouds Batch x nul_points x dim + :return: + -closest point on b of points from a + -closest point on a of points from b + -idx of closest point on b of points from a + -idx of closest point on a of points from b + Works for pointcloud of any dimension + """ + x, y = a.double(), b.double() + bs, num_points_x, points_dim = x.size() + bs, num_points_y, points_dim = y.size() + + xx = torch.pow(x, 2).sum(2) + yy = torch.pow(y, 2).sum(2) + zz = torch.bmm(x, y.transpose(2, 1)) + rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx + ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy + P = rx.transpose(2, 1) + ry - 2 * zz + return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int() + diff --git a/metrics/ChamferDistancePytorch/fscore.py b/metrics/ChamferDistancePytorch/fscore.py new file mode 100644 index 0000000..265378b --- /dev/null +++ b/metrics/ChamferDistancePytorch/fscore.py @@ -0,0 +1,17 @@ +import torch + +def fscore(dist1, dist2, threshold=0.001): + """ + Calculates the F-score between two point clouds with the corresponding threshold value. + :param dist1: Batch, N-Points + :param dist2: Batch, N-Points + :param th: float + :return: fscore, precision, recall + """ + # NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly. 
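+    # e.g. an absolute point-to-point tolerance of ~0.0316 corresponds to the
+    # default squared threshold of 0.001 (0.0316**2 ≈ 0.001)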
+ precision_1 = torch.mean((dist1 < threshold).float(), dim=1) + precision_2 = torch.mean((dist2 < threshold).float(), dim=1) + fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2) + fscore[torch.isnan(fscore)] = 0 + return fscore, precision_1, precision_2 + diff --git a/metrics/ChamferDistancePytorch/unit_test.py b/metrics/ChamferDistancePytorch/unit_test.py new file mode 100644 index 0000000..13af6a3 --- /dev/null +++ b/metrics/ChamferDistancePytorch/unit_test.py @@ -0,0 +1,69 @@ +import torch, time +import chamfer2D.dist_chamfer_2D +import chamfer3D.dist_chamfer_3D +import chamfer5D.dist_chamfer_5D +import chamfer_python + +cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist() +cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist() +cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist() + +from torch.autograd import Variable +from fscore import fscore + +def test_chamfer(distChamfer, dim): + points1 = torch.rand(4, 100, dim).cuda() + points2 = torch.rand(4, 200, dim, requires_grad=True).cuda() + dist1, dist2, idx1, idx2= distChamfer(points1, points2) + + loss = torch.sum(dist1) + loss.backward() + + mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2) + d1 = (dist1 - mydist1) ** 2 + d2 = (dist2 - mydist2) ** 2 + assert ( + torch.mean(d1) + torch.mean(d2) < 0.00000001 + ), "chamfer cuda and chamfer normal are not giving the same results" + + xd1 = idx1 - myidx1 + xd2 = idx2 - myidx2 + assert ( + torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0 + ), "chamfer cuda and chamfer normal are not giving the same results" + print(f"fscore :", fscore(dist1, dist2)) + print("Unit test passed") + + +def timings(distChamfer, dim): + p1 = torch.rand(32, 2000, dim).cuda() + p2 = torch.rand(32, 1000, dim).cuda() + print("Timings : Start CUDA version") + start = time.time() + num_it = 100 + for i in range(num_it): + points1 = Variable(p1, requires_grad=True) + points2 = Variable(p2) + mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2) + loss = torch.sum(mydist1) + loss.backward() + print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.") + + + print("Timings : Start Pythonic version") + start = time.time() + for i in range(num_it): + points1 = Variable(p1, requires_grad=True) + points2 = Variable(p2) + mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2) + loss = torch.sum(mydist1) + loss.backward() + print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.") + + + +dims = [2,3,5] +for i,cham in enumerate([cham2D, cham3D, cham5D]): + print(f"testing Chamfer {dims[i]}D") + test_chamfer(cham, dims[i]) + timings(cham, dims[i]) diff --git a/metrics/PyTorchEMD/.gitignore b/metrics/PyTorchEMD/.gitignore new file mode 100644 index 0000000..8400d00 --- /dev/null +++ b/metrics/PyTorchEMD/.gitignore @@ -0,0 +1,5 @@ +__pycache__ +build +dist +emd_ext.egg-info +*.so diff --git a/metrics/PyTorchEMD/README.md b/metrics/PyTorchEMD/README.md new file mode 100644 index 0000000..8165a45 --- /dev/null +++ b/metrics/PyTorchEMD/README.md @@ -0,0 +1,31 @@ +# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD) + +## Dependency + +The code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0. + +## Usage + +First compile using + + python setup.py install + +Then, copy the lib file out to the main directory, + + cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so . 
+ +Then, you can use it by simply + + from emd import earth_mover_distance + d = earth_mover_distance(p1, p2, transpose=False) # p1: B x N1 x 3, p2: B x N2 x 3 + +Check `test_emd_loss.py` for example. + +## Author + +The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps. + +## License + +MIT + diff --git a/metrics/PyTorchEMD/__init__.py b/metrics/PyTorchEMD/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/metrics/PyTorchEMD/cuda/emd.cpp b/metrics/PyTorchEMD/cuda/emd.cpp new file mode 100644 index 0000000..b94db14 --- /dev/null +++ b/metrics/PyTorchEMD/cuda/emd.cpp @@ -0,0 +1,29 @@ +#ifndef _EMD +#define _EMD + +#include +#include + +//CUDA declarations +at::Tensor ApproxMatchForward( + const at::Tensor xyz1, + const at::Tensor xyz2); + +at::Tensor MatchCostForward( + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor match); + +std::vector MatchCostBackward( + const at::Tensor grad_cost, + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor match); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)"); + m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)"); + m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)"); +} + +#endif diff --git a/metrics/PyTorchEMD/cuda/emd_kernel.cu b/metrics/PyTorchEMD/cuda/emd_kernel.cu new file mode 100644 index 0000000..4744a81 --- /dev/null +++ b/metrics/PyTorchEMD/cuda/emd_kernel.cu @@ -0,0 +1,400 @@ +/********************************** + * Original Author: Haoqiang Fan + * Modified by: Kaichun Mo + *********************************/ + +#ifndef _EMD_KERNEL +#define _EMD_KERNEL + +#include +#include + +#include +#include // at::cuda::getApplyGrid +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + + +/******************************** +* Forward kernel for approxmatch +*********************************/ + +template +__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){ + scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; + scalar_t multiL,multiR; + if (n>=m){ + multiL=1; + multiR=n/m; + }else{ + multiL=m/n; + multiR=1; + } + const int Block=1024; + __shared__ scalar_t buf[Block*4]; + for (int i=blockIdx.x;i=-2;j--){ + scalar_t level=-powf(4.0f,j); + if (j==-2){ + level=0; + } + for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp); +//} + +/* ApproxMatch forward interface +Input: + xyz1: (B, N1, 3) # dataset_points + xyz2: (B, N2, 3) # query_points +Output: + match: (B, N2, N1) +*/ +at::Tensor ApproxMatchForward( + const at::Tensor xyz1, + const at::Tensor xyz2){ + const auto b = xyz1.size(0); + const auto n = xyz1.size(1); + const auto m = xyz2.size(1); + + CHECK_EQ(xyz2.size(0), b); + CHECK_EQ(xyz1.size(2), 3); + CHECK_EQ(xyz2.size(2), 3); + CHECK_INPUT(xyz1); + CHECK_INPUT(xyz2); + + auto match = at::zeros({b, m, n}, xyz1.type()); + auto temp = at::zeros({b, (n+m)*2}, xyz1.type()); + + AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] { + approxmatch<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), 
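+      // approxmatch anneals a soft correspondence ("match") matrix over a
+      // sequence of softness levels, approximating the optimal assignment
+      // without solving the exact matching problem.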
match.data(), temp.data()); + })); + THCudaCheck(cudaGetLastError()); + + return match; +} + + +/******************************** +* Forward kernel for matchcost +*********************************/ + +template +__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){ + __shared__ scalar_t allsum[512]; + const int Block=1024; + __shared__ scalar_t buf[Block*3]; + for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out); +//} + +/* MatchCost forward interface +Input: + xyz1: (B, N1, 3) # dataset_points + xyz2: (B, N2, 3) # query_points + match: (B, N2, N1) +Output: + cost: (B) +*/ +at::Tensor MatchCostForward( + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor match){ + const auto b = xyz1.size(0); + const auto n = xyz1.size(1); + const auto m = xyz2.size(1); + + CHECK_EQ(xyz2.size(0), b); + CHECK_EQ(xyz1.size(2), 3); + CHECK_EQ(xyz2.size(2), 3); + CHECK_INPUT(xyz1); + CHECK_INPUT(xyz2); + + auto cost = at::zeros({b}, xyz1.type()); + + AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] { + matchcost<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), cost.data()); + })); + THCudaCheck(cudaGetLastError()); + + return cost; +} + + +/******************************** +* matchcostgrad2 kernel +*********************************/ + +template +__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){ + __shared__ scalar_t sum_grad[256*3]; + for (int i=blockIdx.x;i +__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){ + for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad1); +// matchcostgrad2<<>>(b,n,m,xyz1,xyz2,match,grad2); +//} + + +/* MatchCost backward interface +Input: + grad_cost: (B) # gradients on cost + xyz1: (B, N1, 3) # dataset_points + xyz2: (B, N2, 3) # query_points + match: (B, N2, N1) +Output: + grad1: (B, N1, 3) + grad2: (B, N2, 3) +*/ +std::vector MatchCostBackward( + const at::Tensor grad_cost, + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor match){ + const auto b = xyz1.size(0); + const auto n = xyz1.size(1); + const auto m = xyz2.size(1); + + CHECK_EQ(xyz2.size(0), b); + CHECK_EQ(xyz1.size(2), 3); + CHECK_EQ(xyz2.size(2), 3); + CHECK_INPUT(xyz1); + CHECK_INPUT(xyz2); + + auto grad1 = at::zeros({b, n, 3}, xyz1.type()); + auto grad2 = at::zeros({b, m, 3}, xyz1.type()); + + AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] { + matchcostgrad1<<<32,512>>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad1.data()); + matchcostgrad2<<>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad2.data()); + })); + THCudaCheck(cudaGetLastError()); + + return std::vector({grad1, grad2}); +} + +#endif diff --git a/metrics/PyTorchEMD/emd.py b/metrics/PyTorchEMD/emd.py new file mode 100644 index 0000000..b0a01ce --- /dev/null +++ b/metrics/PyTorchEMD/emd.py @@ -0,0 +1,47 @@ +import torch +import emd_cuda + + +class EarthMoverDistanceFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + xyz1 = xyz1.contiguous() + xyz2 = xyz2.contiguous() + assert xyz1.is_cuda and xyz2.is_cuda, "Only 
support cuda currently." + match = emd_cuda.approxmatch_forward(xyz1, xyz2) + cost = emd_cuda.matchcost_forward(xyz1, xyz2, match) + ctx.save_for_backward(xyz1, xyz2, match) + return cost + + @staticmethod + def backward(ctx, grad_cost): + xyz1, xyz2, match = ctx.saved_tensors + grad_cost = grad_cost.contiguous() + grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match) + return grad_xyz1, grad_xyz2 + + +def earth_mover_distance(xyz1, xyz2, transpose=True): + """Earth Mover Distance (Approx) + + Args: + xyz1 (torch.Tensor): (b, 3, n1) + xyz2 (torch.Tensor): (b, 3, n1) + transpose (bool): whether to transpose inputs as it might be BCN format. + Extensions only support BNC format. + + Returns: + cost (torch.Tensor): (b) + + """ + if xyz1.dim() == 2: + xyz1 = xyz1.unsqueeze(0) + if xyz2.dim() == 2: + xyz2 = xyz2.unsqueeze(0) + if transpose: + xyz1 = xyz1.transpose(1, 2) + xyz2 = xyz2.transpose(1, 2) + cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) + cost = cost / xyz1.shape[1] + return cost + diff --git a/metrics/PyTorchEMD/setup.py b/metrics/PyTorchEMD/setup.py new file mode 100644 index 0000000..f648c3e --- /dev/null +++ b/metrics/PyTorchEMD/setup.py @@ -0,0 +1,27 @@ +"""Setup extension + +Notes: + If extra_compile_args is provided, you need to provide different instances for different extensions. + Refer to https://github.com/pytorch/pytorch/issues/20169 + +""" + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name='emd_ext', + ext_modules=[ + CUDAExtension( + name='emd_cuda', + sources=[ + 'cuda/emd.cpp', + 'cuda/emd_kernel.cu', + ], + extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} + ), + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/metrics/PyTorchEMD/test_emd_loss.py b/metrics/PyTorchEMD/test_emd_loss.py new file mode 100644 index 0000000..66aa33c --- /dev/null +++ b/metrics/PyTorchEMD/test_emd_loss.py @@ -0,0 +1,44 @@ +import torch +import numpy as np +import time +from emd import earth_mover_distance + +# gt +p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda() +p1 = p1.repeat(3, 1, 1) +p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda() +p2 = p2.repeat(3, 1, 1) +print(p1) +print(p2) +p1.requires_grad = True +p2.requires_grad = True + +gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \ + (((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \ + (((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3 +print('gt_dist: ', gt_dist) + +gt_dist.backward() +print(p1.grad) +print(p2.grad) + +# emd +p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda() +p1 = p1.repeat(3, 1, 1) +p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda() +p2 = p2.repeat(3, 1, 1) +print(p1) +print(p2) +p1.requires_grad = True +p2.requires_grad = True + +d = earth_mover_distance(p1, p2, transpose=False) +print(d) + +loss = d[0] / 2 + d[1] * 2 + d[2] / 3 +print(loss) + +loss.backward() +print(p1.grad) +print(p2.grad) + diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/metrics/evaluation_metrics.py b/metrics/evaluation_metrics.py new file mode 100644 index 0000000..c0cbf62 --- /dev/null +++ b/metrics/evaluation_metrics.py @@ -0,0 +1,322 @@ +import torch +import numpy as np +import warnings +from 
scipy.stats import entropy +from sklearn.neighbors import NearestNeighbors +from numpy.linalg import norm + +from metrics.PyTorchEMD.emd import earth_mover_distance as EMD +from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist +from metrics.ChamferDistancePytorch.fscore import fscore +from tqdm import tqdm + +cham3D = chamfer_3DDist() + +# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet +def distChamfer(a, b): + x, y = a, b + bs, num_points, points_dim = x.size() + xx = torch.bmm(x, x.transpose(2, 1)) + yy = torch.bmm(y, y.transpose(2, 1)) + zz = torch.bmm(x, y.transpose(2, 1)) + diag_ind = torch.arange(0, num_points).to(a).long() + rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) + ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) + P = (rx.transpose(2, 1) + ry - 2 * zz) + return P.min(1)[0], P.min(2)[0] + + +def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True): + N_sample = sample_pcs.shape[0] + N_ref = ref_pcs.shape[0] + assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample) + + cd_lst = [] + emd_lst = [] + fs_lst = [] + iterator = range(0, N_sample, batch_size) + + for b_start in iterator: + b_end = min(N_sample, b_start + batch_size) + sample_batch = sample_pcs[b_start:b_end] + ref_batch = ref_pcs[b_start:b_end] + + dl, dr, _, _ = cham3D(sample_batch.cuda(), ref_batch.cuda()) + fs = fscore(dl, dr)[0].cpu() + fs_lst.append(fs) + cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1)) + + emd_batch = EMD(sample_batch.cuda(), ref_batch.cuda(), transpose=False) + emd_lst.append(emd_batch) + + if reduced: + cd = torch.cat(cd_lst).mean() + emd = torch.cat(emd_lst).mean() + else: + cd = torch.cat(cd_lst) + emd = torch.cat(emd_lst) + fs_lst = torch.cat(fs_lst).mean() + results = { + 'MMD-CD': cd, + 'MMD-EMD': emd, + 'fscore': fs_lst + } + return results + +def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True): + N_sample = sample_pcs.shape[0] + N_ref = ref_pcs.shape[0] + all_cd = [] + all_emd = [] + iterator = range(N_sample) + for sample_b_start in tqdm(iterator): + sample_batch = sample_pcs[sample_b_start] + + cd_lst = [] + emd_lst = [] + for ref_b_start in range(0, N_ref, batch_size): + ref_b_end = min(N_ref, ref_b_start + batch_size) + ref_batch = ref_pcs[ref_b_start:ref_b_end] + + batch_size_ref = ref_batch.size(0) + sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1) + sample_batch_exp = sample_batch_exp.contiguous() + + dl, dr, _, _ = cham3D(sample_batch_exp.cuda(), ref_batch.cuda()) + cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1).detach().cpu()) + + emd_batch = EMD(sample_batch_exp.cuda(), ref_batch.cuda(), transpose=False) + emd_lst.append(emd_batch.view(1, -1).detach().cpu()) + + cd_lst = torch.cat(cd_lst, dim=1) + emd_lst = torch.cat(emd_lst, dim=1) + all_cd.append(cd_lst) + all_emd.append(emd_lst) + + all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref + all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref + + return all_cd, all_emd + + +# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py +def knn(Mxx, Mxy, Myy, k, sqrt=False): + n0 = Mxx.size(0) + n1 = Myy.size(0) + label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx) + M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0) + if sqrt: + M = M.abs().sqrt() + INFINITY = float('inf') + val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False) + + count = torch.zeros(n0 + n1).to(Mxx) + for i in range(0, k): + 
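+        # Each of the k nearest neighbors (self-matches are excluded by the +inf
+        # diagonal above) votes with its set label; a sample is assigned to the
+        # first set when at least half of its neighbors belong to it.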
count = count + label.index_select(0, idx[i]) + pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float() + + s = { + 'tp': (pred * label).sum(), + 'fp': (pred * (1 - label)).sum(), + 'fn': ((1 - pred) * label).sum(), + 'tn': ((1 - pred) * (1 - label)).sum(), + } + + s.update({ + 'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10), + 'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10), + 'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10), + 'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10), + 'acc': torch.eq(label, pred).float().mean(), + }) + return s + + +def lgan_mmd_cov(all_dist): + N_sample, N_ref = all_dist.size(0), all_dist.size(1) + min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) + min_val, _ = torch.min(all_dist, dim=0) + mmd = min_val.mean() + mmd_smp = min_val_fromsmp.mean() + cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) + cov = torch.tensor(cov).to(all_dist) + return { + 'lgan_mmd': mmd, + 'lgan_cov': cov, + 'lgan_mmd_smp': mmd_smp, + } + + +def compute_all_metrics(sample_pcs, ref_pcs, batch_size): + results = {} + + M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size) + + res_cd = lgan_mmd_cov(M_rs_cd.t()) + results.update({ + "%s-CD" % k: v for k, v in res_cd.items() + }) + + res_emd = lgan_mmd_cov(M_rs_emd.t()) + results.update({ + "%s-EMD" % k: v for k, v in res_emd.items() + }) + + M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size) + M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size) + + # 1-NN results + one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) + results.update({ + "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k + }) + one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False) + results.update({ + "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k + }) + + return results + + +####################################################### +# JSD : from https://github.com/optas/latent_3d_points +####################################################### +def unit_cube_grid_point_cloud(resolution, clip_sphere=False): + """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells, + that is placed in the unit-cube. + If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere. + """ + grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) + spacing = 1.0 / float(resolution - 1) + for i in range(resolution): + for j in range(resolution): + for k in range(resolution): + grid[i, j, k, 0] = i * spacing - 0.5 + grid[i, j, k, 1] = j * spacing - 0.5 + grid[i, j, k, 2] = k * spacing - 0.5 + + if clip_sphere: + grid = grid.reshape(-1, 3) + grid = grid[norm(grid, axis=1) <= 0.5] + + return grid, spacing + + +def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28): + """Computes the JSD between two sets of point-clouds, as introduced in the paper + ```Learning Representations And Generative Models For 3D Point Clouds```. + Args: + sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points. + ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points. + resolution: (int) grid-resolution. Affects granularity of measurements. 
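+        Each set is voxelized into a resolution^3 occupancy grid over the unit
+        sphere, and the JSD is computed between the two occupancy histograms.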
+ """ + in_unit_sphere = True + sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1] + ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1] + return jensen_shannon_divergence(sample_grid_var, ref_grid_var) + + +def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False): + """Given a collection of point-clouds, estimate the entropy of the random variables + corresponding to occupancy-grid activation patterns. + Inputs: + pclouds: (numpy array) #point-clouds x points per point-cloud x 3 + grid_resolution (int) size of occupancy grid that will be used. + """ + epsilon = 10e-4 + bound = 0.5 + epsilon + if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: + if verbose: + warnings.warn('Point-clouds are not in unit cube.') + + if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: + if verbose: + warnings.warn('Point-clouds are not in unit sphere.') + + grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere) + grid_coordinates = grid_coordinates.reshape(-1, 3) + grid_counters = np.zeros(len(grid_coordinates)) + grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) + nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates) + + for pc in pclouds: + _, indices = nn.kneighbors(pc) + indices = np.squeeze(indices) + for i in indices: + grid_counters[i] += 1 + indices = np.unique(indices) + for i in indices: + grid_bernoulli_rvars[i] += 1 + + acc_entropy = 0.0 + n = float(len(pclouds)) + for g in grid_bernoulli_rvars: + if g > 0: + p = float(g) / n + acc_entropy += entropy([p, 1.0 - p]) + + return acc_entropy / len(grid_counters), grid_counters + + +def jensen_shannon_divergence(P, Q): + if np.any(P < 0) or np.any(Q < 0): + raise ValueError('Negative values.') + if len(P) != len(Q): + raise ValueError('Non equal size.') + + P_ = P / np.sum(P) # Ensure probabilities. 
+ Q_ = Q / np.sum(Q) + + e1 = entropy(P_, base=2) + e2 = entropy(Q_, base=2) + e_sum = entropy((P_ + Q_) / 2.0, base=2) + res = e_sum - ((e1 + e2) / 2.0) + + res2 = _jsdiv(P_, Q_) + + if not np.allclose(res, res2, atol=10e-5, rtol=0): + warnings.warn('Numerical values of two JSD methods don\'t agree.') + + return res + + +def _jsdiv(P, Q): + """another way of computing JSD""" + + def _kldiv(A, B): + a = A.copy() + b = B.copy() + idx = np.logical_and(a > 0, b > 0) + a = a[idx] + b = b[idx] + return np.sum([v for v in a * np.log2(a / b)]) + + P_ = P / np.sum(P) + Q_ = Q / np.sum(Q) + + M = 0.5 * (P_ + Q_) + + return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M)) + + +if __name__ == "__main__": + B, N = 2, 10 + x = torch.rand(B, N, 3) + y = torch.rand(B, N, 3) + + min_l, min_r = distChamfer(x.cuda(), y.cuda()) + print(min_l.shape) + print(min_r.shape) + + l_dist = min_l.mean().cpu().detach().item() + r_dist = min_r.mean().cpu().detach().item() + print(l_dist, r_dist) + + + emd_batch = EMD(x.cuda(), y.cuda(), False) + print(emd_batch.shape) + print(emd_batch.mean().detach().item()) + + jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy()) + print(jsd) + diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/pvcnn_completion.py b/model/pvcnn_completion.py new file mode 100644 index 0000000..db48b6b --- /dev/null +++ b/model/pvcnn_completion.py @@ -0,0 +1,253 @@ +import functools + +import torch.nn as nn +import torch +import numpy as np +from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish + + +def _linear_gn_relu(in_channels, out_channels): + return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) + + +def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1): + r = width_multiplier + + if dim == 1: + block = _linear_gn_relu + else: + block = SharedMLP + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None): + return nn.Sequential(), in_channels, in_channels + + layers = [] + for oc in out_channels[:-1]: + if oc < 1: + layers.append(nn.Dropout(oc)) + else: + oc = int(r * oc) + layers.append(block(in_channels, oc)) + in_channels = oc + if dim == 1: + if classifier: + layers.append(nn.Linear(in_channels, out_channels[-1])) + else: + layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1]))) + else: + if classifier: + layers.append(nn.Conv1d(in_channels, out_channels[-1], 1)) + else: + layers.append(SharedMLP(in_channels, int(r * out_channels[-1]))) + return layers, out_channels[-1] if classifier else int(r * out_channels[-1]) + + +def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1): + r, vr = width_multiplier, voxel_resolution_multiplier + + layers, concat_channels = [], 0 + c = 0 + for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks): + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = k % 2 == 0 and k > 0 and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + with_se=with_se, normalize=normalize, eps=eps) + + if c == 0: + layers.append(block(in_channels, out_channels)) + else: + 
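+                # Blocks after the first take the timestep embedding concatenated
+                # to the point features, hence the extra embed_dim input channels.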
layers.append(block(in_channels+embed_dim, out_channels)) + in_channels = out_channels + concat_channels += out_channels + c += 1 + return layers, in_channels, concat_channels + + +def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False, + dropout=0.1, with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1): + r, vr = width_multiplier, voxel_resolution_multiplier + in_channels = extra_feature_channels + 3 + + sa_layers, sa_in_channels = [], [] + c = 0 + for conv_configs, sa_configs in sa_blocks: + k = 0 + sa_in_channels.append(in_channels) + sa_blocks = [] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = (c+1) % 2 == 0 and c > 0 and use_att and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se and not attention, with_se_relu=True, + normalize=normalize, eps=eps) + + if c == 0: + sa_blocks.append(block(in_channels, out_channels)) + elif k ==0: + sa_blocks.append(block(in_channels+embed_dim, out_channels)) + in_channels = out_channels + k += 1 + extra_feature_channels = in_channels + num_centers, radius, num_neighbors, out_channels = sa_configs + _out_channels = [] + for oc in out_channels: + if isinstance(oc, (list, tuple)): + _out_channels.append([int(r * _oc) for _oc in oc]) + else: + _out_channels.append(int(r * oc)) + out_channels = _out_channels + if num_centers is None: + block = PointNetAModule + else: + block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius, + num_neighbors=num_neighbors) + sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels, + include_coordinates=True)) + c += 1 + in_channels = extra_feature_channels = sa_blocks[-1].out_channels + if len(sa_blocks) == 1: + sa_layers.append(sa_blocks[0]) + else: + sa_layers.append(nn.Sequential(*sa_blocks)) + + return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers + + +def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_points, embed_dim=64, use_att=False, + dropout=0.1, + with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1): + r, vr = width_multiplier, voxel_resolution_multiplier + + fp_layers = [] + c = 0 + for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks): + fp_blocks = [] + out_channels = tuple(int(r * oc) for oc in fp_configs) + fp_blocks.append( + PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels) + ) + in_channels = out_channels[-1] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = c % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se and not attention,with_se_relu=True, normalize=normalize, eps=eps) + + fp_blocks.append(block(in_channels, out_channels)) + in_channels = out_channels + if len(fp_blocks) == 1: + fp_layers.append(fp_blocks[0]) + else: + 
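+            # Bundle the FP interpolation module and its conv blocks into a single
+            # stage, so forward() consumes one skip connection per fp_layers entry.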
fp_layers.append(nn.Sequential(*fp_blocks)) + + c += 1 + + return fp_layers, in_channels + + +class PVCNN2Base(nn.Module): + + def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout=0.1, + extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1): + super().__init__() + assert extra_feature_channels >= 0 + self.embed_dim = embed_dim + self.sv_points = sv_points + self.in_channels = extra_feature_channels + 3 + + sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components( + sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim, + use_att=use_att, dropout=dropout, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + self.sa_layers = nn.ModuleList(sa_layers) + + self.global_att = None if not use_att else Attention(channels_sa_features, 8, D=1) + + # only use extra features in the last fp module + sa_in_channels[0] = extra_feature_channels + fp_layers, channels_fp_features = create_pointnet2_fp_modules( + fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels,sv_points=sv_points, + with_se=True, embed_dim=embed_dim, + use_att=use_att, dropout=dropout, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + self.fp_layers = nn.ModuleList(fp_layers) + + + layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, 0.5, num_classes], + classifier=True, dim=2, width_multiplier=width_multiplier) + self.classifier = nn.Sequential(*layers) + + self.embedf = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + nn.Linear(embed_dim, embed_dim), + ) + + def get_timestep_embedding(self, timesteps, device): + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + + half_dim = self.embed_dim // 2 + emb = np.log(10000) / (half_dim - 1) + emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device) + # emb = tf.range(num_embeddings, dtype=DEFAULT_DTYPE)[:, None] * emb[None, :] + emb = timesteps[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if self.embed_dim % 2 == 1: # zero pad + # emb = tf.concat([emb, tf.zeros([num_embeddings, 1])], axis=1) + emb = nn.functional.pad(emb, (0, 1), "constant", 0) + assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim]) + return emb + + def forward(self, inputs, t): + + temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1]) + + # inputs : [B, in_channels + S, N] + coords, features = inputs[:, :3, :].contiguous(), inputs + coords_list, in_features_list = [], [] + for i, sa_blocks in enumerate(self.sa_layers): + in_features_list.append(features) + coords_list.append(coords) + if i == 0: + features, coords, temb = sa_blocks ((features, coords, temb)) + else: + features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb)) + in_features_list[0] = inputs[:, 3:, :].contiguous() + if self.global_att is not None: + features = self.global_att(features) + for fp_idx, fp_blocks in enumerate(self.fp_layers): + jump_coords = coords_list[-1 - fp_idx] + fump_feats = in_features_list[-1-fp_idx] + # if fp_idx == len(self.fp_layers) - 1: + # jump_coords = jump_coords[:,:,self.sv_points:] + # fump_feats = fump_feats[:,:,self.sv_points:] + + features, coords, temb = fp_blocks((jump_coords, coords, torch.cat([features,temb],dim=1), fump_feats, 
temb)) + + return self.classifier(features) + + diff --git a/model/pvcnn_generation.py b/model/pvcnn_generation.py new file mode 100644 index 0000000..3926b9e --- /dev/null +++ b/model/pvcnn_generation.py @@ -0,0 +1,247 @@ +import functools + +import torch.nn as nn +import torch +import numpy as np +from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish + + +def _linear_gn_relu(in_channels, out_channels): + return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) + + +def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1): + r = width_multiplier + + if dim == 1: + block = _linear_gn_relu + else: + block = SharedMLP + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None): + return nn.Sequential(), in_channels, in_channels + + layers = [] + for oc in out_channels[:-1]: + if oc < 1: + layers.append(nn.Dropout(oc)) + else: + oc = int(r * oc) + layers.append(block(in_channels, oc)) + in_channels = oc + if dim == 1: + if classifier: + layers.append(nn.Linear(in_channels, out_channels[-1])) + else: + layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1]))) + else: + if classifier: + layers.append(nn.Conv1d(in_channels, out_channels[-1], 1)) + else: + layers.append(SharedMLP(in_channels, int(r * out_channels[-1]))) + return layers, out_channels[-1] if classifier else int(r * out_channels[-1]) + + +def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1): + r, vr = width_multiplier, voxel_resolution_multiplier + + layers, concat_channels = [], 0 + c = 0 + for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks): + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = k % 2 == 0 and k > 0 and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + with_se=with_se, normalize=normalize, eps=eps) + + if c == 0: + layers.append(block(in_channels, out_channels)) + else: + layers.append(block(in_channels+embed_dim, out_channels)) + in_channels = out_channels + concat_channels += out_channels + c += 1 + return layers, in_channels, concat_channels + + +def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False, + dropout=0.1, with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1): + r, vr = width_multiplier, voxel_resolution_multiplier + in_channels = extra_feature_channels + 3 + + sa_layers, sa_in_channels = [], [] + c = 0 + for conv_configs, sa_configs in sa_blocks: + k = 0 + sa_in_channels.append(in_channels) + sa_blocks = [] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = (c+1) % 2 == 0 and use_att and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se, with_se_relu=True, + normalize=normalize, eps=eps) + + if c == 0: + sa_blocks.append(block(in_channels, out_channels)) + elif k ==0: + 
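+                    # Stages after the first see the timestep embedding concatenated
+                    # to their inputs, hence the extra embed_dim channels on the
+                    # stage's first block.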
sa_blocks.append(block(in_channels+embed_dim, out_channels)) + in_channels = out_channels + k += 1 + extra_feature_channels = in_channels + num_centers, radius, num_neighbors, out_channels = sa_configs + _out_channels = [] + for oc in out_channels: + if isinstance(oc, (list, tuple)): + _out_channels.append([int(r * _oc) for _oc in oc]) + else: + _out_channels.append(int(r * oc)) + out_channels = _out_channels + if num_centers is None: + block = PointNetAModule + else: + block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius, + num_neighbors=num_neighbors) + sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels, + include_coordinates=True)) + c += 1 + in_channels = extra_feature_channels = sa_blocks[-1].out_channels + if len(sa_blocks) == 1: + sa_layers.append(sa_blocks[0]) + else: + sa_layers.append(nn.Sequential(*sa_blocks)) + + return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers + + +def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False, + dropout=0.1, + with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1): + r, vr = width_multiplier, voxel_resolution_multiplier + + fp_layers = [] + c = 0 + for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks): + fp_blocks = [] + out_channels = tuple(int(r * oc) for oc in fp_configs) + fp_blocks.append( + PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels) + ) + in_channels = out_channels[-1] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se, with_se_relu=True, + normalize=normalize, eps=eps) + + fp_blocks.append(block(in_channels, out_channels)) + in_channels = out_channels + if len(fp_blocks) == 1: + fp_layers.append(fp_blocks[0]) + else: + fp_layers.append(nn.Sequential(*fp_blocks)) + + c += 1 + + return fp_layers, in_channels + + + +class PVCNN2Base(nn.Module): + + def __init__(self, num_classes, embed_dim, use_att, dropout=0.1, + extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1): + super().__init__() + assert extra_feature_channels >= 0 + self.embed_dim = embed_dim + self.in_channels = extra_feature_channels + 3 + + sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components( + sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim, + use_att=use_att, dropout=dropout, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + self.sa_layers = nn.ModuleList(sa_layers) + + self.global_att = None if not use_att else Attention(channels_sa_features, 8, D=1) + + # only use extra features in the last fp module + sa_in_channels[0] = extra_feature_channels + fp_layers, channels_fp_features = create_pointnet2_fp_modules( + fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels, with_se=True, embed_dim=embed_dim, + use_att=use_att, dropout=dropout, + width_multiplier=width_multiplier, 
voxel_resolution_multiplier=voxel_resolution_multiplier + ) + self.fp_layers = nn.ModuleList(fp_layers) + + + layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, dropout, num_classes], # was 0.5 + classifier=True, dim=2, width_multiplier=width_multiplier) + self.classifier = nn.Sequential(*layers) + + self.embedf = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + nn.Linear(embed_dim, embed_dim), + ) + + def get_timestep_embedding(self, timesteps, device): + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + + half_dim = self.embed_dim // 2 + emb = np.log(10000) / (half_dim - 1) + emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device) + # emb = tf.range(num_embeddings, dtype=DEFAULT_DTYPE)[:, None] * emb[None, :] + emb = timesteps[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if self.embed_dim % 2 == 1: # zero pad + # emb = tf.concat([emb, tf.zeros([num_embeddings, 1])], axis=1) + emb = nn.functional.pad(emb, (0, 1), "constant", 0) + assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim]) + return emb + + def forward(self, inputs, t): + + temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1]) + + # inputs : [B, in_channels + S, N] + coords, features = inputs[:, :3, :].contiguous(), inputs + coords_list, in_features_list = [], [] + for i, sa_blocks in enumerate(self.sa_layers): + in_features_list.append(features) + coords_list.append(coords) + if i == 0: + features, coords, temb = sa_blocks ((features, coords, temb)) + else: + features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb)) + in_features_list[0] = inputs[:, 3:, :].contiguous() + if self.global_att is not None: + features = self.global_att(features) + for fp_idx, fp_blocks in enumerate(self.fp_layers): + features, coords, temb = fp_blocks((coords_list[-1-fp_idx], coords, torch.cat([features,temb],dim=1), in_features_list[-1-fp_idx], temb)) + + return self.classifier(features) + + diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..89290fc --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,8 @@ +from modules.ball_query import BallQuery +from modules.frustum import FrustumPointNetLoss +from modules.loss import KLLoss +from modules.pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule +from modules.pvconv import PVConv, Attention, Swish, PVConvReLU +from modules.se import SE3d +from modules.shared_mlp import SharedMLP +from modules.voxelization import Voxelization diff --git a/modules/ball_query.py b/modules/ball_query.py new file mode 100644 index 0000000..20251d0 --- /dev/null +++ b/modules/ball_query.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn + +import modules.functional as F + +__all__ = ['BallQuery'] + + +class BallQuery(nn.Module): + def __init__(self, radius, num_neighbors, include_coordinates=True): + super().__init__() + self.radius = radius + self.num_neighbors = num_neighbors + self.include_coordinates = include_coordinates + + def forward(self, points_coords, centers_coords, temb, points_features=None): + points_coords = points_coords.contiguous() + centers_coords = centers_coords.contiguous() + neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors) + neighbor_coordinates = F.grouping(points_coords, neighbor_indices) + neighbor_coordinates = neighbor_coordinates - 
centers_coords.unsqueeze(-1) + + if points_features is None: + assert self.include_coordinates, 'No Features For Grouping' + neighbor_features = neighbor_coordinates + else: + neighbor_features = F.grouping(points_features, neighbor_indices) + if self.include_coordinates: + neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1) + return neighbor_features, F.grouping(temb, neighbor_indices) + + def extra_repr(self): + return 'radius={}, num_neighbors={}{}'.format( + self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '') diff --git a/modules/frustum.py b/modules/frustum.py new file mode 100644 index 0000000..e8d95d2 --- /dev/null +++ b/modules/frustum.py @@ -0,0 +1,138 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modules.functional as PF + +__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d'] + + +class FrustumPointNetLoss(nn.Module): + def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0, + corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0): + super().__init__() + self.box_loss_weight = box_loss_weight + self.corners_loss_weight = corners_loss_weight + self.heading_residual_loss_weight = heading_residual_loss_weight + self.size_residual_loss_weight = size_residual_loss_weight + + self.num_heading_angle_bins = num_heading_angle_bins + self.num_size_templates = num_size_templates + self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3)) + self.register_buffer( + 'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins) + ) + + def forward(self, inputs, targets): + mask_logits = inputs['mask_logits'] # (B, 2, N) + center_reg = inputs['center_reg'] # (B, 3) + center = inputs['center'] # (B, 3) + heading_scores = inputs['heading_scores'] # (B, NH) + heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH) + heading_residuals = inputs['heading_residuals'] # (B, NH) + size_scores = inputs['size_scores'] # (B, NS) + size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3) + size_residuals = inputs['size_residuals'] # (B, NS, 3) + + mask_logits_target = targets['mask_logits'] # (B, N) + center_target = targets['center'] # (B, 3) + heading_bin_id_target = targets['heading_bin_id'] # (B, ) + heading_residual_target = targets['heading_residual'] # (B, ) + size_template_id_target = targets['size_template_id'] # (B, ) + size_residual_target = targets['size_residual'] # (B, 3) + + batch_size = center.size(0) + batch_id = torch.arange(batch_size, device=center.device) + + # Basic Classification and Regression losses + mask_loss = F.cross_entropy(mask_logits, mask_logits_target) + heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target) + size_loss = F.cross_entropy(size_scores, size_template_id_target) + center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0) + center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0) + + # Refinement losses for size/heading + heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, ) + heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins) + heading_residual_normalized_loss = PF.huber_loss( + heading_residuals_normalized - heading_residual_normalized_target, delta=1.0 + ) + 
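+        # Size residuals are likewise normalized by the matched size template,
+        # making the regression target scale-free.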
size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3) + size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target] + size_residual_normalized_loss = PF.huber_loss( + torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0 + ) + + # Bounding box losses + heading = (heading_residuals[batch_id, heading_bin_id_target] + + self.heading_angle_bin_centers[heading_bin_id_target]) # (B, ) + # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets) + size = (size_residuals[batch_id, size_template_id_target] + + self.size_templates[size_template_id_target]) # (B, 3) + corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8) + heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, ) + size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3) + corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target, + sizes=size_target, with_flip=True) # (B, 3, 8) + corners_loss = PF.huber_loss(torch.min( + torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1) + ), delta=1.0) + # Summing up + loss = mask_loss + self.box_loss_weight * ( + center_loss + center_reg_loss + heading_loss + size_loss + + self.heading_residual_loss_weight * heading_residual_normalized_loss + + self.size_residual_loss_weight * size_residual_normalized_loss + + self.corners_loss_weight * corners_loss + ) + + return loss + + +def get_box_corners_3d(centers, headings, sizes, with_flip=False): + """ + :param centers: coords of box centers, FloatTensor[N, 3] + :param headings: heading angles, FloatTensor[N, ] + :param sizes: box sizes, FloatTensor[N, 3] + :param with_flip: bool, whether to return flipped box (headings + np.pi) + :return: + coords of box corners, FloatTensor[N, 3, 8] + NOTE: corner points are in counter clockwise order, e.g., + 2--1 + 3--0 5 + 7--4 + """ + l = sizes[:, 0] # (N,) + w = sizes[:, 1] # (N,) + h = sizes[:, 2] # (N,) + x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8) + y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8) + z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8) + + c = torch.cos(headings) # (N,) + s = torch.sin(headings) # (N,) + o = torch.ones_like(headings) # (N,) + z = torch.zeros_like(headings) # (N,) + + centers = centers.unsqueeze(-1) # (B, 3, 1) + corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8) + R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3) + if with_flip: + R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3) + return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers + else: + return torch.matmul(R, corners) + centers + + # centers = centers.unsqueeze(1) # (B, 1, 3) + # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3) + # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # if with_flip: + # RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3) + # else: + # return torch.matmul(corners, RT) + centers # 
(N, 8, 3) + + # corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8) + # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8) + # corners = corners.transpose(1, 2) # (N, 8, 3) diff --git a/modules/functional/__init__.py b/modules/functional/__init__.py new file mode 100644 index 0000000..ce707cc --- /dev/null +++ b/modules/functional/__init__.py @@ -0,0 +1,7 @@ +from modules.functional.ball_query import ball_query +from modules.functional.devoxelization import trilinear_devoxelize +from modules.functional.grouping import grouping +from modules.functional.interpolatation import nearest_neighbor_interpolate +from modules.functional.loss import kl_loss, huber_loss +from modules.functional.sampling import gather, furthest_point_sample, logits_mask +from modules.functional.voxelization import avg_voxelize diff --git a/modules/functional/backend.py b/modules/functional/backend.py new file mode 100644 index 0000000..794e0d6 --- /dev/null +++ b/modules/functional/backend.py @@ -0,0 +1,26 @@ +import os + +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) +_backend = load(name='_pvcnn_backend', + extra_cflags=['-O3', '-std=c++17'], + extra_cuda_cflags=['--compiler-bindir=/usr/bin/gcc-8'], + sources=[os.path.join(_src_path,'src', f) for f in [ + 'ball_query/ball_query.cpp', + 'ball_query/ball_query.cu', + 'grouping/grouping.cpp', + 'grouping/grouping.cu', + 'interpolate/neighbor_interpolate.cpp', + 'interpolate/neighbor_interpolate.cu', + 'interpolate/trilinear_devox.cpp', + 'interpolate/trilinear_devox.cu', + 'sampling/sampling.cpp', + 'sampling/sampling.cu', + 'voxelization/vox.cpp', + 'voxelization/vox.cu', + 'bindings.cpp', + ]] + ) + +__all__ = ['_backend'] diff --git a/modules/functional/ball_query.py b/modules/functional/ball_query.py new file mode 100644 index 0000000..a99df0d --- /dev/null +++ b/modules/functional/ball_query.py @@ -0,0 +1,19 @@ +from torch.autograd import Function + +from modules.functional.backend import _backend + +__all__ = ['ball_query'] + + +def ball_query(centers_coords, points_coords, radius, num_neighbors): + """ + :param centers_coords: coordinates of centers, FloatTensor[B, 3, M] + :param points_coords: coordinates of points, FloatTensor[B, 3, N] + :param radius: float, radius of ball query + :param num_neighbors: int, maximum number of neighbors + :return: + neighbor_indices: indices of neighbors, IntTensor[B, M, U] + """ + centers_coords = centers_coords.contiguous() + points_coords = points_coords.contiguous() + return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors) diff --git a/modules/functional/devoxelization.py b/modules/functional/devoxelization.py new file mode 100644 index 0000000..b037f48 --- /dev/null +++ b/modules/functional/devoxelization.py @@ -0,0 +1,42 @@ +from torch.autograd import Function + +from modules.functional.backend import _backend + +__all__ = ['trilinear_devoxelize'] + + +class TrilinearDevoxelization(Function): + @staticmethod + def forward(ctx, features, coords, resolution, is_training=True): + """ + :param ctx: + :param coords: the coordinates of points, FloatTensor[B, 3, N] + :param features: FloatTensor[B, C, R, R, R] + :param resolution: int, the voxel resolution + :param is_training: bool, training mode + :return: + FloatTensor[B, C, N] + """ + B, C = features.shape[:2] + features = features.contiguous().view(B, C, -1) + coords = 
coords.contiguous() + outs, inds, wgts = _backend.trilinear_devoxelize_forward(resolution, is_training, coords, features) + if is_training: + ctx.save_for_backward(inds, wgts) + ctx.r = resolution + return outs + + @staticmethod + def backward(ctx, grad_output): + """ + :param ctx: + :param grad_output: gradient of outputs, FloatTensor[B, C, N] + :return: + gradient of inputs, FloatTensor[B, C, R, R, R] + """ + inds, wgts = ctx.saved_tensors + grad_inputs = _backend.trilinear_devoxelize_backward(grad_output.contiguous(), inds, wgts, ctx.r) + return grad_inputs.view(grad_output.size(0), grad_output.size(1), ctx.r, ctx.r, ctx.r), None, None, None + + +trilinear_devoxelize = TrilinearDevoxelization.apply diff --git a/modules/functional/grouping.py b/modules/functional/grouping.py new file mode 100644 index 0000000..72855ea --- /dev/null +++ b/modules/functional/grouping.py @@ -0,0 +1,31 @@ +from torch.autograd import Function + +from modules.functional.backend import _backend + +__all__ = ['grouping'] + + +class Grouping(Function): + @staticmethod + def forward(ctx, features, indices): + """ + :param ctx: + :param features: features of points, FloatTensor[B, C, N] + :param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors + :return: + grouped_features: grouped features, FloatTensor[B, C, M, U] + """ + features = features.contiguous() + indices = indices.contiguous() + ctx.save_for_backward(indices) + ctx.num_points = features.size(-1) + return _backend.grouping_forward(features, indices) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points) + return grad_features, None + + +grouping = Grouping.apply diff --git a/modules/functional/interpolatation.py b/modules/functional/interpolatation.py new file mode 100644 index 0000000..5a42425 --- /dev/null +++ b/modules/functional/interpolatation.py @@ -0,0 +1,38 @@ +from torch.autograd import Function + +from modules.functional.backend import _backend + +__all__ = ['nearest_neighbor_interpolate'] + + +class NeighborInterpolation(Function): + @staticmethod + def forward(ctx, points_coords, centers_coords, centers_features): + """ + :param ctx: + :param points_coords: coordinates of points, FloatTensor[B, 3, N] + :param centers_coords: coordinates of centers, FloatTensor[B, 3, M] + :param centers_features: features of centers, FloatTensor[B, C, M] + :return: + points_features: features of points, FloatTensor[B, C, N] + """ + centers_coords = centers_coords.contiguous() + points_coords = points_coords.contiguous() + centers_features = centers_features.contiguous() + points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward( + points_coords, centers_coords, centers_features + ) + ctx.save_for_backward(indices, weights) + ctx.num_centers = centers_coords.size(-1) + return points_features + + @staticmethod + def backward(ctx, grad_output): + indices, weights = ctx.saved_tensors + grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward( + grad_output.contiguous(), indices, weights, ctx.num_centers + ) + return None, None, grad_centers_features + + +nearest_neighbor_interpolate = NeighborInterpolation.apply diff --git a/modules/functional/loss.py b/modules/functional/loss.py new file mode 100644 index 0000000..41112b3 --- /dev/null +++ b/modules/functional/loss.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + +__all__ = 
['kl_loss', 'huber_loss'] + + +def kl_loss(x, y): + x = F.softmax(x.detach(), dim=1) + y = F.log_softmax(y, dim=1) + return torch.mean(torch.sum(x * (torch.log(x) - y), dim=1)) + + +def huber_loss(error, delta): + abs_error = torch.abs(error) + quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta)) + losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic) + return torch.mean(losses) diff --git a/modules/functional/sampling.py b/modules/functional/sampling.py new file mode 100644 index 0000000..160450b --- /dev/null +++ b/modules/functional/sampling.py @@ -0,0 +1,84 @@ +import numpy as np +import torch +from torch.autograd import Function + +from modules.functional.backend import _backend + +__all__ = ['gather', 'furthest_point_sample', 'logits_mask'] + + +class Gather(Function): + @staticmethod + def forward(ctx, features, indices): + """ + Gather + :param ctx: + :param features: features of points, FloatTensor[B, C, N] + :param indices: centers' indices in points, IntTensor[b, m] + :return: + centers_coords: coordinates of sampled centers, FloatTensor[B, C, M] + """ + features = features.contiguous() + indices = indices.int().contiguous() + ctx.save_for_backward(indices) + ctx.num_points = features.size(-1) + return _backend.gather_features_forward(features, indices) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points) + return grad_features, None + + +gather = Gather.apply + + +def furthest_point_sample(coords, num_samples): + """ + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance to the sampled point set + :param coords: coordinates of points, FloatTensor[B, 3, N] + :param num_samples: int, M + :return: + centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M] + """ + coords = coords.contiguous() + indices = _backend.furthest_point_sampling(coords, num_samples) + return gather(coords, indices) + + +def logits_mask(coords, logits, num_points_per_object): + """ + Use logits to sample points + :param coords: coords of points, FloatTensor[B, 3, N] + :param logits: binary classification logits, FloatTensor[B, 2, N] + :param num_points_per_object: M, #points per object after masking, int + :return: + selected_coords: FloatTensor[B, 3, M] + masked_coords_mean: mean coords of selected points, FloatTensor[B, 3] + mask: mask to select points, BoolTensor[B, N] + """ + batch_size, _, num_points = coords.shape + mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N] + num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1] + masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N] + masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, + torch.ones_like(num_candidates)).float() # [B, C] + selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32) + for i in range(batch_size): + current_mask = mask[i] # [N] + current_candidates = current_mask.nonzero().view(-1) + current_num_candidates = current_candidates.numel() + if current_num_candidates >= num_points_per_object: + choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False) + selected_indices[i] = current_candidates[choices] + elif current_num_candidates > 0: + choices = np.concatenate([ + np.arange(current_num_candidates).repeat(num_points_per_object // 
+
+
+def logits_mask(coords, logits, num_points_per_object):
+    """
+    Use logits to sample points
+    :param coords: coords of points, FloatTensor[B, 3, N]
+    :param logits: binary classification logits, FloatTensor[B, 2, N]
+    :param num_points_per_object: M, #points per object after masking, int
+    :return:
+        selected_coords: FloatTensor[B, 3, M]
+        masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
+        mask: mask to select points, BoolTensor[B, N]
+    """
+    batch_size, _, num_points = coords.shape
+    mask = torch.lt(logits[:, 0, :], logits[:, 1, :])  # [B, N]
+    num_candidates = torch.sum(mask, dim=-1, keepdim=True)  # [B, 1]
+    masked_coords = coords * mask.view(batch_size, 1, num_points)  # [B, C, N]
+    masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
+                                                                      torch.ones_like(num_candidates)).float()  # [B, C]
+    selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
+    for i in range(batch_size):
+        current_mask = mask[i]  # [N]
+        current_candidates = current_mask.nonzero().view(-1)
+        current_num_candidates = current_candidates.numel()
+        if current_num_candidates >= num_points_per_object:
+            choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
+            selected_indices[i] = current_candidates[choices]
+        elif current_num_candidates > 0:
+            choices = np.concatenate([
+                np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
+                np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
+            ])
+            np.random.shuffle(choices)
+            selected_indices[i] = current_candidates[choices]
+    selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
+    return selected_coords, masked_coords_mean, mask
diff --git a/modules/functional/src/ball_query/ball_query.cpp b/modules/functional/src/ball_query/ball_query.cpp
new file mode 100644
index 0000000..5ae1fb6
--- /dev/null
+++ b/modules/functional/src/ball_query/ball_query.cpp
@@ -0,0 +1,30 @@
+#include "ball_query.hpp"
+#include "ball_query.cuh"
+
+#include "../utils.hpp"
+
+at::Tensor ball_query_forward(at::Tensor centers_coords,
+                              at::Tensor points_coords, const float radius,
+                              const int num_neighbors) {
+  CHECK_CUDA(centers_coords);
+  CHECK_CUDA(points_coords);
+  CHECK_CONTIGUOUS(centers_coords);
+  CHECK_CONTIGUOUS(points_coords);
+  CHECK_IS_FLOAT(centers_coords);
+  CHECK_IS_FLOAT(points_coords);
+
+  int b = centers_coords.size(0);
+  int m = centers_coords.size(2);
+  int n = points_coords.size(2);
+
+  at::Tensor neighbors_indices = torch::zeros(
+      {b, m, num_neighbors},
+      at::device(centers_coords.device()).dtype(at::ScalarType::Int));
+
+  ball_query(b, n, m, radius * radius, num_neighbors,
+             centers_coords.data_ptr<float>(),
+             points_coords.data_ptr<float>(),
+             neighbors_indices.data_ptr<int>());
+
+  return neighbors_indices;
+}
diff --git a/modules/functional/src/ball_query/ball_query.cu b/modules/functional/src/ball_query/ball_query.cu
new file mode 100644
index 0000000..079e3cb
--- /dev/null
+++ b/modules/functional/src/ball_query/ball_query.cu
@@ -0,0 +1,59 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+  Function: ball query
+  Args:
+    b   : batch size
+    n   : number of points in point clouds
+    m   : number of query centers
+    r2  : ball query radius ** 2
+    u   : maximum number of neighbors
+    centers_coords: coordinates of centers, FloatTensor[b, 3, m]
+    points_coords : coordinates of points, FloatTensor[b, 3, n]
+    neighbors_indices : neighbor indices in points, IntTensor[b, m, u]
+*/
+__global__ void ball_query_kernel(int b, int n, int m, float r2, int u,
+                                  const float *__restrict__ centers_coords,
+                                  const float *__restrict__ points_coords,
+                                  int *__restrict__ neighbors_indices) {
+  int batch_index = blockIdx.x;
+  int index = threadIdx.x;
+  int stride = blockDim.x;
+  points_coords += batch_index * n * 3;
+  centers_coords += batch_index * m * 3;
+  neighbors_indices += batch_index * m * u;
+
+  for (int j = index; j < m; j += stride) {
+    float center_x = centers_coords[j];
+    float center_y = centers_coords[j + m];
+    float center_z = centers_coords[j + m + m];
+    for (int k = 0, cnt = 0; k < n && cnt < u; ++k) {
+      float dx = center_x - points_coords[k];
+      float dy = center_y - points_coords[k + n];
+      float dz = center_z - points_coords[k + n + n];
+      float d2 = dx * dx + dy * dy + dz * dz;
+      if (d2 < r2) {
+        if (cnt == 0) {
+          for (int v = 0; v < u; ++v) {
+            neighbors_indices[j * u + v] = k;
+          }
+        }
+        neighbors_indices[j * u + cnt] = k;
+        ++cnt;
+      }
+    }
+  }
+}
+
+void ball_query(int b, int n, int m, float r2, int u,
+                const float *centers_coords, const float *points_coords,
+                int *neighbors_indices) {
+  ball_query_kernel<<<b, optimal_num_threads(m), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, n, m, r2, u, centers_coords, points_coords, neighbors_indices);
+  CUDA_CHECK_ERRORS();
+}
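The kernel scans the points in index order and keeps the first `u` hits inside the radius, pre-filling all slots with the first hit so shorter lists come out padded. A rough pure-PyTorch reference of that contract (an illustrative sketch only, much slower, and it differs for centers with no neighbor at all, which the kernel leaves at zero):

```python
import torch

def ball_query_reference(centers, points, radius, u):
    # centers: [B, 3, M], points: [B, 3, N] -> neighbor indices [B, M, u]
    n = points.shape[-1]
    d2 = torch.cdist(centers.transpose(1, 2), points.transpose(1, 2)) ** 2  # [B, M, N]
    hit = d2 < radius * radius
    idx = torch.arange(n, device=points.device).expand_as(hit).clone()
    idx[~hit] = n                      # push non-neighbors past the end
    idx, _ = torch.sort(idx, dim=-1)   # hits stay in ascending point order, like the kernel
    idx = idx[..., :u]
    first = idx[..., :1].clamp(max=n - 1)
    return torch.where(idx == n, first, idx).int()  # pad short lists with the first hit
```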
diff --git a/modules/functional/src/ball_query/ball_query.cuh b/modules/functional/src/ball_query/ball_query.cuh
new file mode 100644
index 0000000..ba32492
--- /dev/null
+++ b/modules/functional/src/ball_query/ball_query.cuh
@@ -0,0 +1,8 @@
+#ifndef _BALL_QUERY_CUH
+#define _BALL_QUERY_CUH
+
+void ball_query(int b, int n, int m, float r2, int u,
+                const float *centers_coords, const float *points_coords,
+                int *neighbors_indices);
+
+#endif
diff --git a/modules/functional/src/ball_query/ball_query.hpp b/modules/functional/src/ball_query/ball_query.hpp
new file mode 100644
index 0000000..d87bbd9
--- /dev/null
+++ b/modules/functional/src/ball_query/ball_query.hpp
@@ -0,0 +1,10 @@
+#ifndef _BALL_QUERY_HPP
+#define _BALL_QUERY_HPP
+
+#include <torch/extension.h>
+
+at::Tensor ball_query_forward(at::Tensor centers_coords,
+                              at::Tensor points_coords, const float radius,
+                              const int num_neighbors);
+
+#endif
diff --git a/modules/functional/src/bindings.cpp b/modules/functional/src/bindings.cpp
new file mode 100644
index 0000000..994e01b
--- /dev/null
+++ b/modules/functional/src/bindings.cpp
@@ -0,0 +1,37 @@
+#include <pybind11/pybind11.h>
+
+#include "ball_query/ball_query.hpp"
+#include "grouping/grouping.hpp"
+#include "interpolate/neighbor_interpolate.hpp"
+#include "interpolate/trilinear_devox.hpp"
+#include "sampling/sampling.hpp"
+#include "voxelization/vox.hpp"
+
+PYBIND11_MODULE(_pvcnn_backend, m) {
+  m.def("gather_features_forward", &gather_features_forward,
+        "Gather Centers' Features forward (CUDA)");
+  m.def("gather_features_backward", &gather_features_backward,
+        "Gather Centers' Features backward (CUDA)");
+  m.def("furthest_point_sampling", &furthest_point_sampling_forward,
+        "Furthest Point Sampling (CUDA)");
+  m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)");
+  m.def("grouping_forward", &grouping_forward,
+        "Grouping Features forward (CUDA)");
+  m.def("grouping_backward", &grouping_backward,
+        "Grouping Features backward (CUDA)");
+  m.def("three_nearest_neighbors_interpolate_forward",
+        &three_nearest_neighbors_interpolate_forward,
+        "3 Nearest Neighbors Interpolate forward (CUDA)");
+  m.def("three_nearest_neighbors_interpolate_backward",
+        &three_nearest_neighbors_interpolate_backward,
+        "3 Nearest Neighbors Interpolate backward (CUDA)");
+
+  m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward,
+        "Trilinear Devoxelization forward (CUDA)");
+  m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward,
+        "Trilinear Devoxelization backward (CUDA)");
+  m.def("avg_voxelize_forward", &avg_voxelize_forward,
+        "Voxelization forward with average pooling (CUDA)");
+  m.def("avg_voxelize_backward", &avg_voxelize_backward,
+        "Voxelization backward (CUDA)");
+}
diff --git a/modules/functional/src/cuda_utils.cuh b/modules/functional/src/cuda_utils.cuh
new file mode 100644
index 0000000..01bf551
--- /dev/null
+++ b/modules/functional/src/cuda_utils.cuh
@@ -0,0 +1,39 @@
+#ifndef _CUDA_UTILS_H
+#define _CUDA_UTILS_H
+
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+
+#define MAXIMUM_THREADS 512
+
+inline int optimal_num_threads(int work_size) {
+  const int pow_2 = std::log2(static_cast<double>(work_size));
+  return max(min(1 << pow_2, MAXIMUM_THREADS), 1);
+}
+
+inline dim3 optimal_block_config(int x, int y) {
+  const int x_threads = optimal_num_threads(x);
+  const int y_threads =
+      max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1);
+  dim3 block_config(x_threads, y_threads, 1);
+  return block_config;
+}
+
+#define CUDA_CHECK_ERRORS() \
+  { \
+    cudaError_t err = cudaGetLastError(); \
+    if (cudaSuccess != err) { \
+      fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
+              cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
+              __FILE__); \
+      exit(-1); \
+    } \
+  }
+
+#endif
diff --git a/modules/functional/src/grouping/grouping.cpp b/modules/functional/src/grouping/grouping.cpp
new file mode 100644
index 0000000..4f97650
--- /dev/null
+++ b/modules/functional/src/grouping/grouping.cpp
@@ -0,0 +1,44 @@
+#include "grouping.hpp"
+#include "grouping.cuh"
+
+#include "../utils.hpp"
+
+at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) {
+  CHECK_CUDA(features);
+  CHECK_CUDA(indices);
+  CHECK_CONTIGUOUS(features);
+  CHECK_CONTIGUOUS(indices);
+  CHECK_IS_FLOAT(features);
+  CHECK_IS_INT(indices);
+
+  int b = features.size(0);
+  int c = features.size(1);
+  int n = features.size(2);
+  int m = indices.size(1);
+  int u = indices.size(2);
+  at::Tensor output = torch::zeros(
+      {b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float));
+  grouping(b, c, n, m, u, features.data_ptr<float>(), indices.data_ptr<int>(),
+           output.data_ptr<float>());
+  return output;
+}
+
+at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
+                             const int n) {
+  CHECK_CUDA(grad_y);
+  CHECK_CUDA(indices);
+  CHECK_CONTIGUOUS(grad_y);
+  CHECK_CONTIGUOUS(indices);
+  CHECK_IS_FLOAT(grad_y);
+  CHECK_IS_INT(indices);
+
+  int b = grad_y.size(0);
+  int c = grad_y.size(1);
+  int m = indices.size(1);
+  int u = indices.size(2);
+  at::Tensor grad_x = torch::zeros(
+      {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+  grouping_grad(b, c, n, m, u, grad_y.data_ptr<float>(),
+                indices.data_ptr<int>(), grad_x.data_ptr<float>());
+  return grad_x;
+}
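`grouping_forward` is a batched indexed gather, `out[b, c, m, k] = features[b, c, indices[b, m, k]]`, and the backward pass scatter-adds gradients along the same indices. An equivalent pure-PyTorch sketch of the forward (illustrative only, not the code path this extension uses):

```python
import torch

def grouping_reference(features, indices):
    # features: [B, C, N], indices: [B, M, U] -> grouped features [B, C, M, U]
    b, c, _ = features.shape
    _, m, u = indices.shape
    flat = indices.view(b, 1, m * u).expand(b, c, m * u).long()
    return torch.gather(features, 2, flat).view(b, c, m, u)
```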
diff --git a/modules/functional/src/grouping/grouping.cu b/modules/functional/src/grouping/grouping.cu
new file mode 100644
index 0000000..0cf561a
--- /dev/null
+++ b/modules/functional/src/grouping/grouping.cu
@@ -0,0 +1,85 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+  Function: grouping features of neighbors (forward)
+  Args:
+    b   : batch size
+    c   : #channels of features
+    n   : number of points in point clouds
+    m   : number of query centers
+    u   : maximum number of neighbors
+    features: points' features, FloatTensor[b, c, n]
+    indices : neighbor indices in points, IntTensor[b, m, u]
+    out : gathered features, FloatTensor[b, c, m, u]
+*/
+__global__ void grouping_kernel(int b, int c, int n, int m, int u,
+                                const float *__restrict__ features,
+                                const int *__restrict__ indices,
+                                float *__restrict__ out) {
+  int batch_index = blockIdx.x;
+  features += batch_index * n * c;
+  indices += batch_index * m * u;
+  out += batch_index * m * u * c;
+
+  const int index = threadIdx.y * blockDim.x + threadIdx.x;
+  const int stride = blockDim.y * blockDim.x;
+  for (int i = index; i < c * m; i += stride) {
+    const int l = i / m;
+    const int j = i % m;
+    for (int k = 0; k < u; ++k) {
+      out[(l * m + j) * u + k] = features[l * n + indices[j * u + k]];
+    }
+  }
+}
+
+void grouping(int b, int c, int n, int m, int u, const float *features,
+              const int *indices, float *out) {
+  grouping_kernel<<<b, optimal_block_config(m, c), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, c, n, m, u, features, indices, out);
+  CUDA_CHECK_ERRORS();
+}
+
+/*
+  Function: grouping features of neighbors (backward)
+  Args:
+    b   : batch size
+    c   : #channels of features
+    n   : number of points in point clouds
+    m   : number of query centers
+    u   : maximum number of neighbors
+    grad_y : grad of gathered features, FloatTensor[b, c, m, u]
+    indices : neighbor indices in points, IntTensor[b, m, u]
+    grad_x: grad of points' features, FloatTensor[b, c, n]
+*/
+__global__ void grouping_grad_kernel(int b, int c, int n, int m, int u,
+                                     const float *__restrict__ grad_y,
+                                     const int *__restrict__ indices,
+                                     float *__restrict__ grad_x) {
+  int batch_index = blockIdx.x;
+  grad_y += batch_index * m * u * c;
+  indices += batch_index * m * u;
+  grad_x += batch_index * n * c;
+
+  const int index = threadIdx.y * blockDim.x + threadIdx.x;
+  const int stride = blockDim.y * blockDim.x;
+  for (int i = index; i < c * m; i += stride) {
+    const int l = i / m;
+    const int j = i % m;
+    for (int k = 0; k < u; ++k) {
+      atomicAdd(grad_x + l * n + indices[j * u + k],
+                grad_y[(l * m + j) * u + k]);
+    }
+  }
+}
+
+void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
+                   const int *indices, float *grad_x) {
+  grouping_grad_kernel<<<b, optimal_block_config(m, c), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, c, n, m, u, grad_y, indices, grad_x);
+  CUDA_CHECK_ERRORS();
+}
diff --git a/modules/functional/src/grouping/grouping.cuh b/modules/functional/src/grouping/grouping.cuh
new file mode 100644
index 0000000..c8a114f
--- /dev/null
+++ b/modules/functional/src/grouping/grouping.cuh
@@ -0,0 +1,9 @@
+#ifndef _GROUPING_CUH
+#define _GROUPING_CUH
+
+void grouping(int b, int c, int n, int m, int u, const float *features,
+              const int *indices, float *out);
+void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
+                   const int *indices, float *grad_x);
+
+#endif
\ No newline at end of file
diff --git a/modules/functional/src/grouping/grouping.hpp b/modules/functional/src/grouping/grouping.hpp
new file mode 100644
index 0000000..3f5733d
--- /dev/null
+++ b/modules/functional/src/grouping/grouping.hpp
@@ -0,0 +1,10 @@
+#ifndef _GROUPING_HPP
+#define _GROUPING_HPP
+
+#include <torch/extension.h>
+
+at::Tensor grouping_forward(at::Tensor features, at::Tensor indices);
+at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
+                             const int n);
+
+#endif
diff --git a/modules/functional/src/interpolate/neighbor_interpolate.cpp b/modules/functional/src/interpolate/neighbor_interpolate.cpp
new file mode 100644
index 0000000..fc73c43
--- /dev/null
+++ b/modules/functional/src/interpolate/neighbor_interpolate.cpp
@@ -0,0 +1,65 @@
+#include "neighbor_interpolate.hpp"
+#include "neighbor_interpolate.cuh"
+
+#include "../utils.hpp"
+
+std::vector<at::Tensor>
+three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
+                                            at::Tensor centers_coords,
+                                            at::Tensor centers_features) {
+  CHECK_CUDA(points_coords);
+  CHECK_CUDA(centers_coords);
+  CHECK_CUDA(centers_features);
+  CHECK_CONTIGUOUS(points_coords);
+  CHECK_CONTIGUOUS(centers_coords);
+  CHECK_CONTIGUOUS(centers_features);
+  CHECK_IS_FLOAT(points_coords);
+  CHECK_IS_FLOAT(centers_coords);
+  CHECK_IS_FLOAT(centers_features);
+
+  int b = centers_features.size(0);
+  int c = centers_features.size(1);
+  int m = centers_features.size(2);
+  int n = points_coords.size(2);
+
+  at::Tensor indices = torch::zeros(
+      {b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int));
+  at::Tensor weights = torch::zeros(
+      {b, 3, n},
+      at::device(points_coords.device()).dtype(at::ScalarType::Float));
+  at::Tensor output = torch::zeros(
+      {b, c, n},
+      at::device(centers_features.device()).dtype(at::ScalarType::Float));
+
+  three_nearest_neighbors_interpolate(
+      b, c, m, n, points_coords.data_ptr<float>(),
+      centers_coords.data_ptr<float>(), centers_features.data_ptr<float>(),
+      indices.data_ptr<int>(), weights.data_ptr<float>(),
+      output.data_ptr<float>());
+  return {output, indices, weights};
+}
+
+at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
+                                                        at::Tensor indices,
+                                                        at::Tensor weights,
+                                                        const int m) {
+  CHECK_CUDA(grad_y);
+  CHECK_CUDA(indices);
+  CHECK_CUDA(weights);
+  CHECK_CONTIGUOUS(grad_y);
+  CHECK_CONTIGUOUS(indices);
+  CHECK_CONTIGUOUS(weights);
+  CHECK_IS_FLOAT(grad_y);
+  CHECK_IS_INT(indices);
+  CHECK_IS_FLOAT(weights);
+
+  int b = grad_y.size(0);
+  int c = grad_y.size(1);
+  int n = grad_y.size(2);
+  at::Tensor grad_x = torch::zeros(
+      {b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+  three_nearest_neighbors_interpolate_grad(
+      b, c, n, m, grad_y.data_ptr<float>(), indices.data_ptr<int>(),
+      weights.data_ptr<float>(), grad_x.data_ptr<float>());
+  return grad_x;
+}
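The forward pass finds the three nearest centers per point and turns their squared distances d0 <= d1 <= d2 into normalized inverse-distance weights, w0 = d1*d2 / (d0*d1 + d0*d2 + d1*d2) and so on. The same weighting in a few lines of PyTorch (a sketch for intuition, not the implementation below):

```python
import torch

def three_nn_weights(points, centers):
    # points: [B, 3, N], centers: [B, 3, M]
    d2 = torch.cdist(points.transpose(1, 2), centers.transpose(1, 2)) ** 2  # [B, N, M]
    dist, idx = torch.topk(d2, k=3, dim=-1, largest=False)                  # [B, N, 3]
    w = 1.0 / dist.clamp(min=1e-10)
    w = w / w.sum(dim=-1, keepdim=True)   # equals d1*d2 / (d0*d1 + d0*d2 + d1*d2), etc.
    return idx.transpose(1, 2).int(), w.transpose(1, 2)  # [B, 3, N] each
```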
diff --git a/modules/functional/src/interpolate/neighbor_interpolate.cu b/modules/functional/src/interpolate/neighbor_interpolate.cu
new file mode 100644
index 0000000..8168507
--- /dev/null
+++ b/modules/functional/src/interpolate/neighbor_interpolate.cu
@@ -0,0 +1,181 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+  Function: three nearest neighbors
+  Args:
+    b   : batch size
+    n   : number of points in point clouds
+    m   : number of query centers
+    points_coords : coordinates of points, FloatTensor[b, 3, n]
+    centers_coords: coordinates of centers, FloatTensor[b, 3, m]
+    weights : weights of nearest 3 centers to the point,
+              FloatTensor[b, 3, n]
+    indices : indices of nearest 3 centers to the point,
+              IntTensor[b, 3, n]
+*/
+__global__ void three_nearest_neighbors_kernel(
+    int b, int n, int m, const float *__restrict__ points_coords,
+    const float *__restrict__ centers_coords, float *__restrict__ weights,
+    int *__restrict__ indices) {
+  int batch_index = blockIdx.x;
+  int index = threadIdx.x;
+  int stride = blockDim.x;
+  points_coords += batch_index * 3 * n;
+  weights += batch_index * 3 * n;
+  indices += batch_index * 3 * n;
+  centers_coords += batch_index * 3 * m;
+
+  for (int j = index; j < n; j += stride) {
+    float ux = points_coords[j];
+    float uy = points_coords[j + n];
+    float uz = points_coords[j + n + n];
+
+    double best0 = 1e40, best1 = 1e40, best2 = 1e40;
+    int besti0 = 0, besti1 = 0, besti2 = 0;
+    for (int k = 0; k < m; ++k) {
+      float x = centers_coords[k];
+      float y = centers_coords[k + m];
+      float z = centers_coords[k + m + m];
+      float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+      if (d < best2) {
+        best2 = d;
+        besti2 = k;
+        if (d < best1) {
+          best2 = best1;
+          besti2 = besti1;
+          best1 = d;
+          besti1 = k;
+          if (d < best0) {
+            best1 = best0;
+            besti1 = besti0;
+            best0 = d;
+            besti0 = k;
+          }
+        }
+      }
+    }
+    best0 = max(min(1e10f, best0), 1e-10f);
+    best1 = max(min(1e10f, best1), 1e-10f);
+    best2 = max(min(1e10f, best2), 1e-10f);
+    float d0d1 = best0 * best1;
+    float d0d2 = best0 * best2;
+    float d1d2 = best1 * best2;
+    float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2);
+    weights[j] = d1d2 * d0d1d2;
+    indices[j] = besti0;
+    weights[j + n] = d0d2 * d0d1d2;
+    indices[j + n] = besti1;
+    weights[j + n + n] = d0d1 * d0d1d2;
+    indices[j + n + n] = besti2;
+  }
+}
+
+/*
+  Function: interpolate three nearest neighbors (forward)
+  Args:
+    b   : batch size
+    c   : #channels of features
+    m   : number of query centers
+    n   : number of points in point clouds
+    centers_features: features of centers, FloatTensor[b, c, m]
+    indices : indices of nearest 3 centers to the point,
+              IntTensor[b, 3, n]
+    weights : weights for interpolation, FloatTensor[b, 3, n]
+    out : features of points, FloatTensor[b, c, n]
+*/
+__global__ void three_nearest_neighbors_interpolate_kernel(
+    int b, int c, int m, int n, const float *__restrict__ centers_features,
+    const int *__restrict__ indices, const float *__restrict__ weights,
+    float *__restrict__ out) {
+  int batch_index = blockIdx.x;
+  centers_features += batch_index * m * c;
+  indices += batch_index * n * 3;
+  weights += batch_index * n * 3;
+  out += batch_index * n * c;
+
+  const int index = threadIdx.y * blockDim.x + threadIdx.x;
+  const int stride = blockDim.y * blockDim.x;
+  for (int i = index; i < c * n; i += stride) {
+    const int l = i / n;
+    const int j = i % n;
+    float w1 = weights[j];
+    float w2 = weights[j + n];
+    float w3 = weights[j + n + n];
+    int i1 = indices[j];
+    int i2 = indices[j + n];
+    int i3 = indices[j + n + n];
+
+    out[i] = centers_features[l * m + i1] * w1 +
+             centers_features[l * m + i2] * w2 +
+             centers_features[l * m + i3] * w3;
+  }
+}
+
+void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
+                                         const float *points_coords,
+                                         const float *centers_coords,
+                                         const float *centers_features,
+                                         int *indices, float *weights,
+                                         float *out) {
+  three_nearest_neighbors_kernel<<<b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, n, m, points_coords, centers_coords, weights, indices);
+  three_nearest_neighbors_interpolate_kernel<<<
+      b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, c, m, n, centers_features, indices, weights, out);
+  CUDA_CHECK_ERRORS();
+}
+
+/*
+  Function: interpolate three nearest neighbors (backward)
+  Args:
+    b   : batch size
+    c   : #channels of features
+    m   : number of query centers
+    n   : number of points in point clouds
+    grad_y  : grad of features of points, FloatTensor[b, c, n]
+    indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n]
+    weights : weights for interpolation, FloatTensor[b, 3, n]
+    grad_x  : grad of features of centers, FloatTensor[b, c, m]
*/
+__global__ void three_nearest_neighbors_interpolate_grad_kernel(
+    int b, int c, int n, int m, const float *__restrict__ grad_y,
+    const int *__restrict__ indices, const float *__restrict__ weights,
+    float *__restrict__ grad_x) {
+  int batch_index = blockIdx.x;
+  grad_y += batch_index * n * c;
+  indices += batch_index * n * 3;
+  weights += batch_index * n * 3;
+  grad_x += batch_index * m * c;
+
+  const int index = threadIdx.y * blockDim.x + threadIdx.x;
+  const int stride = blockDim.y * blockDim.x;
+  for (int i = index; i < c * n; i += stride) {
+    const int l = i / n;
+    const int j = i % n;
+    float w1 = weights[j];
+    float w2 = weights[j + n];
+    float w3 = weights[j + n + n];
+    int i1 = indices[j];
+    int i2 = indices[j + n];
+    int i3 = indices[j + n + n];
+    atomicAdd(grad_x + l * m + i1, grad_y[i] * w1);
+    atomicAdd(grad_x + l * m + i2, grad_y[i] * w2);
+    atomicAdd(grad_x + l * m + i3, grad_y[i] * w3);
+  }
+}
+
+void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
+                                              const float *grad_y,
+                                              const int *indices,
+                                              const float *weights,
+                                              float *grad_x) {
+  three_nearest_neighbors_interpolate_grad_kernel<<<
+      b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, c, n, m, grad_y, indices, weights, grad_x);
+  CUDA_CHECK_ERRORS();
+}
diff --git a/modules/functional/src/interpolate/neighbor_interpolate.cuh b/modules/functional/src/interpolate/neighbor_interpolate.cuh
new file mode 100644
index 0000000..a15f37e
--- /dev/null
+++ b/modules/functional/src/interpolate/neighbor_interpolate.cuh
@@ -0,0 +1,16 @@
+#ifndef _NEIGHBOR_INTERPOLATE_CUH
+#define _NEIGHBOR_INTERPOLATE_CUH
+
+void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
+                                         const float *points_coords,
+                                         const float *centers_coords,
+                                         const float *centers_features,
+                                         int *indices, float *weights,
+                                         float *out);
+void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
+                                              const float *grad_y,
+                                              const int *indices,
+                                              const float *weights,
+                                              float *grad_x);
+
+#endif
diff --git a/modules/functional/src/interpolate/neighbor_interpolate.hpp b/modules/functional/src/interpolate/neighbor_interpolate.hpp
new file mode 100644
index 0000000..cdc7835
--- /dev/null
+++ b/modules/functional/src/interpolate/neighbor_interpolate.hpp
@@ -0,0 +1,16 @@
+#ifndef _NEIGHBOR_INTERPOLATE_HPP
+#define _NEIGHBOR_INTERPOLATE_HPP
+
+#include <torch/extension.h>
+#include <vector>
+
+std::vector<at::Tensor>
+three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
+                                            at::Tensor centers_coords,
+                                            at::Tensor centers_features);
+at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
+                                                        at::Tensor indices,
+                                                        at::Tensor weights,
+                                                        const int m);
+
+#endif
diff --git a/modules/functional/src/interpolate/trilinear_devox.cpp b/modules/functional/src/interpolate/trilinear_devox.cpp
new file mode 100644
index 0000000..a8ff4fc
--- /dev/null
+++ b/modules/functional/src/interpolate/trilinear_devox.cpp
@@ -0,0 +1,91 @@
+#include "trilinear_devox.hpp"
+#include "trilinear_devox.cuh"
+
+#include "../utils.hpp"
+
+/*
+  Function: trilinear devoxelization (forward)
+  Args:
+    r           : voxel resolution
+    is_training : whether in training mode
+    coords      : the coordinates of points, FloatTensor[b, 3, n]
+    features    : features, FloatTensor[b, c, s], s = r ** 3
+  Return:
+    outs : outputs, FloatTensor[b, c, n]
+    inds : the voxel coordinates of point cube, IntTensor[b, 8, n]
+    wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
+*/
+std::vector<at::Tensor>
+trilinear_devoxelize_forward(const int r, const bool is_training,
+                             const at::Tensor coords,
+                             const at::Tensor features) {
+  CHECK_CUDA(features);
+  CHECK_CUDA(coords);
+  CHECK_CONTIGUOUS(features);
+  CHECK_CONTIGUOUS(coords);
+  CHECK_IS_FLOAT(features);
+  CHECK_IS_FLOAT(coords);
+
+  int b = features.size(0);
+  int c = features.size(1);
+  int n = coords.size(2);
+  int r2 = r * r;
+  int r3 = r2 * r;
+  at::Tensor outs = torch::zeros(
+      {b, c, n}, at::device(features.device()).dtype(at::ScalarType::Float));
+  if (is_training) {
+    at::Tensor inds = torch::zeros(
+        {b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Int));
+    at::Tensor wgts = torch::zeros(
+        {b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Float));
+    trilinear_devoxelize(b, c, n, r, r2, r3, true, coords.data_ptr<float>(),
+                         features.data_ptr<float>(), inds.data_ptr<int>(),
+                         wgts.data_ptr<float>(), outs.data_ptr<float>());
+    return {outs, inds, wgts};
+  } else {
+    at::Tensor inds = torch::zeros(
+        {1}, at::device(features.device()).dtype(at::ScalarType::Int));
+    at::Tensor wgts = torch::zeros(
+        {1}, at::device(features.device()).dtype(at::ScalarType::Float));
+    trilinear_devoxelize(b, c, n, r, r2, r3, false, coords.data_ptr<float>(),
+                         features.data_ptr<float>(), inds.data_ptr<int>(),
+                         wgts.data_ptr<float>(), outs.data_ptr<float>());
+    return {outs, inds, wgts};
+  }
+}
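For intuition, the weights the devoxelization kernels below compute for a point at continuous voxel coordinate (x, y, z) are the standard trilinear corner products, in the order wgt000 ... wgt111 with z varying fastest. A single-point PyTorch sketch of that weighting:

```python
import torch

def trilinear_corner_weights(p):
    # p: [3] continuous coordinate in voxel space, e.g. torch.tensor([3.2, 0.7, 5.5])
    d1 = p - torch.floor(p)   # fractional part, pulls toward the "high" corner
    d0 = 1.0 - d1
    w = []
    for dx in (d0[0], d1[0]):
        for dy in (d0[1], d1[1]):
            for dz in (d0[2], d1[2]):
                w.append(dx * dy * dz)
    return torch.stack(w)     # 8 weights (wgt000 ... wgt111), summing to 1
```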
+
+/*
+  Function: trilinear devoxelization (backward)
+  Args:
+    grad_y  : grad outputs, FloatTensor[b, c, n]
+    indices : the voxel coordinates of point cube, IntTensor[b, 8, n]
+    weights : weight for trilinear interpolation, FloatTensor[b, 8, n]
+    r       : voxel resolution
+  Return:
+    grad_x : grad inputs, FloatTensor[b, c, s], s = r ** 3
+*/
+at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
+                                         const at::Tensor indices,
+                                         const at::Tensor weights,
+                                         const int r) {
+  CHECK_CUDA(grad_y);
+  CHECK_CUDA(weights);
+  CHECK_CUDA(indices);
+  CHECK_CONTIGUOUS(grad_y);
+  CHECK_CONTIGUOUS(weights);
+  CHECK_CONTIGUOUS(indices);
+  CHECK_IS_FLOAT(grad_y);
+  CHECK_IS_FLOAT(weights);
+  CHECK_IS_INT(indices);
+
+  int b = grad_y.size(0);
+  int c = grad_y.size(1);
+  int n = grad_y.size(2);
+  int r3 = r * r * r;
+  at::Tensor grad_x = torch::zeros(
+      {b, c, r3}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+  trilinear_devoxelize_grad(b, c, n, r3, indices.data_ptr<int>(),
+                            weights.data_ptr<float>(), grad_y.data_ptr<float>(),
+                            grad_x.data_ptr<float>());
+  return grad_x;
+}
diff --git a/modules/functional/src/interpolate/trilinear_devox.cu b/modules/functional/src/interpolate/trilinear_devox.cu
new file mode 100644
index 0000000..4e1e50c
--- /dev/null
+++ b/modules/functional/src/interpolate/trilinear_devox.cu
@@ -0,0 +1,178 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+  Function: trilinear devoxelization (forward)
+  Args:
+    b   : batch size
+    c   : #channels
+    n   : number of points
+    r   : voxel resolution
+    r2  : r ** 2
+    r3  : r ** 3
+    coords : the coordinates of points, FloatTensor[b, 3, n]
+    feat   : features, FloatTensor[b, c, r3]
+    inds   : the voxel indices of point cube, IntTensor[b, 8, n]
+    wgts   : weight for trilinear interpolation, FloatTensor[b, 8, n]
+    outs   : outputs, FloatTensor[b, c, n]
+*/
+__global__ void trilinear_devoxelize_kernel(int b, int c, int n, int r, int r2,
+                                            int r3, bool is_training,
+                                            const float *__restrict__ coords,
+                                            const float *__restrict__ feat,
+                                            int *__restrict__ inds,
+                                            float *__restrict__ wgts,
+                                            float *__restrict__ outs) {
+  int batch_index = blockIdx.x;
+  int stride = blockDim.x;
+  int index = threadIdx.x;
+  coords += batch_index * n * 3;
+  inds += batch_index * n * 8;
+  wgts += batch_index * n * 8;
+  feat += batch_index * c * r3;
+  outs += batch_index * c * n;
+
+  for (int i = index; i < n; i += stride) {
+    float x = coords[i];
+    float y = coords[i + n];
+    float z = coords[i + n + n];
+    float x_lo_f = floorf(x);
+    float y_lo_f = floorf(y);
+    float z_lo_f = floorf(z);
+
+    float x_d_1 = x - x_lo_f; // / (x_hi_f - x_lo_f + 1e-8f)
+    float y_d_1 = y - y_lo_f;
+    float z_d_1 = z - z_lo_f;
+    float x_d_0 = 1.0f - x_d_1;
+    float y_d_0 = 1.0f - y_d_1;
+    float z_d_0 = 1.0f - z_d_1;
+
+    float wgt000 = x_d_0 * y_d_0 * z_d_0;
+    float wgt001 = x_d_0 * y_d_0 * z_d_1;
+    float wgt010 = x_d_0 * y_d_1 * z_d_0;
+    float wgt011 = x_d_0 * y_d_1 * z_d_1;
+    float wgt100 = x_d_1 * y_d_0 * z_d_0;
+    float wgt101 = x_d_1 * y_d_0 * z_d_1;
+    float wgt110 = x_d_1 * y_d_1 * z_d_0;
+    float wgt111 = x_d_1 * y_d_1 * z_d_1;
+
+    int x_lo = static_cast<int>(x_lo_f);
+    int y_lo = static_cast<int>(y_lo_f);
+    int z_lo = static_cast<int>(z_lo_f);
+    int x_hi = (x_d_1 > 0) ? -1 : 0;
+    int y_hi = (y_d_1 > 0) ? -1 : 0;
+    int z_hi = (z_d_1 > 0) ? 1 : 0;
+
+    int idx000 = x_lo * r2 + y_lo * r + z_lo;
+    int idx001 = idx000 + z_hi;        // x_lo * r2 + y_lo * r + z_hi;
+    int idx010 = idx000 + (y_hi & r);  // x_lo * r2 + y_hi * r + z_lo;
+    int idx011 = idx010 + z_hi;        // x_lo * r2 + y_hi * r + z_hi;
+    int idx100 = idx000 + (x_hi & r2); // x_hi * r2 + y_lo * r + z_lo;
+    int idx101 = idx100 + z_hi;        // x_hi * r2 + y_lo * r + z_hi;
+    int idx110 = idx100 + (y_hi & r);  // x_hi * r2 + y_hi * r + z_lo;
+    int idx111 = idx110 + z_hi;        // x_hi * r2 + y_hi * r + z_hi;
+
+    if (is_training) {
+      wgts[i] = wgt000;
+      wgts[i + n] = wgt001;
+      wgts[i + n * 2] = wgt010;
+      wgts[i + n * 3] = wgt011;
+      wgts[i + n * 4] = wgt100;
+      wgts[i + n * 5] = wgt101;
+      wgts[i + n * 6] = wgt110;
+      wgts[i + n * 7] = wgt111;
+      inds[i] = idx000;
+      inds[i + n] = idx001;
+      inds[i + n * 2] = idx010;
+      inds[i + n * 3] = idx011;
+      inds[i + n * 4] = idx100;
+      inds[i + n * 5] = idx101;
+      inds[i + n * 6] = idx110;
+      inds[i + n * 7] = idx111;
+    }
+
+    for (int j = 0; j < c; j++) {
+      int jr3 = j * r3;
+      outs[j * n + i] =
+          wgt000 * feat[jr3 + idx000] + wgt001 * feat[jr3 + idx001] +
+          wgt010 * feat[jr3 + idx010] + wgt011 * feat[jr3 + idx011] +
+          wgt100 * feat[jr3 + idx100] + wgt101 * feat[jr3 + idx101] +
+          wgt110 * feat[jr3 + idx110] + wgt111 * feat[jr3 + idx111];
+    }
+  }
+}
+
+/*
+  Function: trilinear devoxelization (backward)
+  Args:
+    b   : batch size
+    c   : #channels
+    n   : number of points
+    r3  : voxel cube size = voxel resolution ** 3
+    inds   : the voxel indices of point cube, IntTensor[b, 8, n]
+    wgts   : weight for trilinear interpolation, FloatTensor[b, 8, n]
+    grad_y : grad outputs, FloatTensor[b, c, n]
+    grad_x : grad inputs, FloatTensor[b, c, r3]
+*/
+__global__ void trilinear_devoxelize_grad_kernel(
+    int b, int c, int n, int r3, const int *__restrict__ inds,
+    const float *__restrict__ wgts, const float *__restrict__ grad_y,
+    float *__restrict__ grad_x) {
+  int batch_index = blockIdx.x;
+  int stride = blockDim.x;
+  int index = threadIdx.x;
+  inds += batch_index * n * 8;
+  wgts += batch_index * n * 8;
+  grad_x += batch_index * c * r3;
+  grad_y += batch_index * c * n;
+
+  for (int i = index; i < n; i += stride) {
+    int idx000 = inds[i];
+    int idx001 = inds[i + n];
+    int idx010 = inds[i + n * 2];
+    int idx011 = inds[i + n * 3];
+    int idx100 = inds[i + n * 4];
+    int idx101 = inds[i + n * 5];
+    int idx110 = inds[i + n * 6];
+    int idx111 = inds[i + n * 7];
+    float wgt000 = wgts[i];
+    float wgt001 = wgts[i + n];
+    float wgt010 = wgts[i + n * 2];
+    float wgt011 = wgts[i + n * 3];
+    float wgt100 = wgts[i + n * 4];
+    float wgt101 = wgts[i + n * 5];
+    float wgt110 = wgts[i + n * 6];
+    float wgt111 = wgts[i + n * 7];
+
+    for (int j = 0; j < c; j++) {
+      int jr3 = j * r3;
+      float g = grad_y[j * n + i];
+      atomicAdd(grad_x + jr3 + idx000, wgt000 * g);
+      atomicAdd(grad_x + jr3 + idx001, wgt001 * g);
+      atomicAdd(grad_x + jr3 + idx010, wgt010 * g);
+      atomicAdd(grad_x + jr3 + idx011, wgt011 * g);
+      atomicAdd(grad_x + jr3 + idx100, wgt100 * g);
+      atomicAdd(grad_x + jr3 + idx101, wgt101 * g);
+      atomicAdd(grad_x + jr3 + idx110, wgt110 * g);
+      atomicAdd(grad_x + jr3 + idx111, wgt111 * g);
+    }
+  }
+}
+
+void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
+                          bool training, const float *coords, const float *feat,
+                          int *inds, float *wgts, float *outs) {
+  trilinear_devoxelize_kernel<<<b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, c, n, r, r2, r3, training, coords, feat, inds, wgts, outs);
+  CUDA_CHECK_ERRORS();
+}
+
+void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
+                               const float *wgts, const float *grad_y,
+                               float *grad_x) {
+  trilinear_devoxelize_grad_kernel<<<b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, c, n, r3, inds, wgts, grad_y, grad_x);
+  CUDA_CHECK_ERRORS();
+}
diff --git a/modules/functional/src/interpolate/trilinear_devox.cuh b/modules/functional/src/interpolate/trilinear_devox.cuh
new file mode 100644
index 0000000..8aadbaf
--- /dev/null
+++ b/modules/functional/src/interpolate/trilinear_devox.cuh
@@ -0,0 +1,13 @@
+#ifndef _TRILINEAR_DEVOX_CUH
+#define _TRILINEAR_DEVOX_CUH
+
+// CUDA function declarations
+void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
+                          bool is_training, const float *coords,
+                          const float *feat, int *inds, float *wgts,
+                          float *outs);
+void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
+                               const float *wgts, const float *grad_y,
+                               float *grad_x);
+
+#endif
diff --git a/modules/functional/src/interpolate/trilinear_devox.hpp b/modules/functional/src/interpolate/trilinear_devox.hpp
new file mode 100644
index 0000000..a9d6795
--- /dev/null
+++ b/modules/functional/src/interpolate/trilinear_devox.hpp
@@ -0,0 +1,16 @@
+#ifndef _TRILINEAR_DEVOX_HPP
+#define _TRILINEAR_DEVOX_HPP
+
+#include <torch/extension.h>
+#include <vector>
+
+std::vector<at::Tensor> trilinear_devoxelize_forward(const int r,
+                                                     const bool is_training,
+                                                     const at::Tensor coords,
+                                                     const at::Tensor features);
+
+at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
+                                         const at::Tensor indices,
+                                         const at::Tensor weights, const int r);
+
+#endif
diff --git a/modules/functional/src/sampling/sampling.cpp b/modules/functional/src/sampling/sampling.cpp
new file mode 100644
index 0000000..9b8ca6e
--- /dev/null
+++ b/modules/functional/src/sampling/sampling.cpp
@@ -0,0 +1,58 @@
+#include "sampling.hpp"
+#include "sampling.cuh"
+
+#include "../utils.hpp"
+
+at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices) {
+  CHECK_CUDA(features);
+  CHECK_CUDA(indices);
+  CHECK_CONTIGUOUS(features);
+  CHECK_CONTIGUOUS(indices);
+  CHECK_IS_FLOAT(features);
+  CHECK_IS_INT(indices);
+
+  int b = features.size(0);
+  int c = features.size(1);
+  int n = features.size(2);
+  int m = indices.size(1);
+  at::Tensor output = torch::zeros(
+      {b, c, m}, at::device(features.device()).dtype(at::ScalarType::Float));
+  gather_features(b, c, n, m, features.data_ptr<float>(),
+                  indices.data_ptr<int>(), output.data_ptr<float>());
+  return output;
+}
+
+at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
+                                    const int n) {
+  CHECK_CUDA(grad_y);
+  CHECK_CUDA(indices);
+  CHECK_CONTIGUOUS(grad_y);
+  CHECK_CONTIGUOUS(indices);
+  CHECK_IS_FLOAT(grad_y);
+  CHECK_IS_INT(indices);
+
+  int b = grad_y.size(0);
+  int c = grad_y.size(1);
+  at::Tensor grad_x = torch::zeros(
+      {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+  gather_features_grad(b, c, n, indices.size(1), grad_y.data_ptr<float>(),
+                       indices.data_ptr<int>(), grad_x.data_ptr<float>());
+  return grad_x;
+}
+
+at::Tensor furthest_point_sampling_forward(at::Tensor coords,
+                                           const int num_samples) {
+  CHECK_CUDA(coords);
+  CHECK_CONTIGUOUS(coords);
+  CHECK_IS_FLOAT(coords);
+
+  int b = coords.size(0);
+  int n = coords.size(2);
+  at::Tensor indices = torch::zeros(
+      {b, num_samples}, at::device(coords.device()).dtype(at::ScalarType::Int));
+  at::Tensor distances = torch::full(
+      {b, n}, 1e38f, at::device(coords.device()).dtype(at::ScalarType::Float));
+  furthest_point_sampling(b, n, num_samples, coords.data_ptr<float>(),
+                          distances.data_ptr<float>(), indices.data_ptr<int>());
+  return indices;
+}
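Because gather/scatter pairs are easy to get subtly wrong, a cheap sanity check is to compare the custom op against a native PyTorch gather and run a backward pass through it. A test sketch (not part of this commit; assumes a CUDA build of the extension):

```python
import torch
from modules.functional.sampling import gather

feats = torch.rand(2, 8, 64, device='cuda', requires_grad=True)
idx = torch.randint(0, 64, (2, 16), device='cuda', dtype=torch.int32)

out = gather(feats, idx)                                          # CUDA op: [2, 8, 16]
ref = feats.gather(2, idx.long().unsqueeze(1).expand(-1, 8, -1))  # native reference
assert torch.allclose(out, ref)

out.sum().backward()   # exercises gather_features_backward (scatter-add)
```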
diff --git a/modules/functional/src/sampling/sampling.cu b/modules/functional/src/sampling/sampling.cu
new file mode 100644
index 0000000..06bc0ee
--- /dev/null
+++ b/modules/functional/src/sampling/sampling.cu
@@ -0,0 +1,174 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+  Function: gather centers' features (forward)
+  Args:
+    b   : batch size
+    c   : #channels of features
+    n   : number of points in point clouds
+    m   : number of query/sampled centers
+    features: points' features, FloatTensor[b, c, n]
+    indices : centers' indices in points, IntTensor[b, m]
+    out : gathered features, FloatTensor[b, c, m]
+*/
+__global__ void gather_features_kernel(int b, int c, int n, int m,
+                                       const float *__restrict__ features,
+                                       const int *__restrict__ indices,
+                                       float *__restrict__ out) {
+  int batch_index = blockIdx.x;
+  int channel_index = blockIdx.y;
+  int temp_index = batch_index * c + channel_index;
+  features += temp_index * n;
+  indices += batch_index * m;
+  out += temp_index * m;
+
+  for (int j = threadIdx.x; j < m; j += blockDim.x) {
+    out[j] = features[indices[j]];
+  }
+}
+
+void gather_features(int b, int c, int n, int m, const float *features,
+                     const int *indices, float *out) {
+  gather_features_kernel<<<dim3(b, c, 1), optimal_num_threads(m), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, c, n, m, features, indices, out);
+  CUDA_CHECK_ERRORS();
+}
+
+/*
+  Function: gather centers' features (backward)
+  Args:
+    b   : batch size
+    c   : #channels of features
+    n   : number of points in point clouds
+    m   : number of query/sampled centers
+    grad_y  : grad of gathered features, FloatTensor[b, c, m]
+    indices : centers' indices in points, IntTensor[b, m]
+    grad_x  : grad of points' features, FloatTensor[b, c, n]
+*/
+__global__ void gather_features_grad_kernel(int b, int c, int n, int m,
+                                            const float *__restrict__ grad_y,
+                                            const int *__restrict__ indices,
+                                            float *__restrict__ grad_x) {
+  int batch_index = blockIdx.x;
+  int channel_index = blockIdx.y;
+  int temp_index = batch_index * c + channel_index;
+  grad_y += temp_index * m;
+  indices += batch_index * m;
+  grad_x += temp_index * n;
+
+  for (int j = threadIdx.x; j < m; j += blockDim.x) {
+    atomicAdd(grad_x + indices[j], grad_y[j]);
+  }
+}
+
+void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
+                          const int *indices, float *grad_x) {
+  gather_features_grad_kernel<<<dim3(b, c, 1), optimal_num_threads(m), 0, at::cuda::getCurrentCUDAStream()>>>(
+      b, c, n, m, grad_y, indices, grad_x);
+  CUDA_CHECK_ERRORS();
+}
+
+/*
+  Function: furthest point sampling
+  Args:
+    b   : batch size
+    n   : number of points in point clouds
+    m   : number of query/sampled centers
+    coords    : points' coords, FloatTensor[b, 3, n]
+    distances : minimum distance of a point to the set, FloatTensor[b, n]
+    indices   : sampled centers' indices in points, IntTensor[b, m]
+*/
+__global__ void furthest_point_sampling_kernel(int b, int n, int m,
+                                               const float *__restrict__ coords,
+                                               float *__restrict__ distances,
+                                               int *__restrict__ indices) {
+  if (m <= 0)
+    return;
+  int batch_index = blockIdx.x;
+  coords += batch_index * n * 3;
+  distances += batch_index * n;
+  indices += batch_index * m;
+
+  const int BlockSize = 512;
+  __shared__ float dists[BlockSize];
+  __shared__ int dists_i[BlockSize];
+  const int BufferSize = 3072;
+  __shared__ float buf[BufferSize * 3];
+
+  int old = 0;
+  if (threadIdx.x == 0)
+    indices[0] = old;
+
+  for (int j = threadIdx.x; j < min(BufferSize, n); j += blockDim.x) {
+    buf[j] = coords[j];
+    buf[j + BufferSize] = coords[j + n];
+    buf[j + BufferSize + BufferSize] = coords[j + n + n];
+  }
+  __syncthreads();
+
+  for (int j = 1; j < m; j++) {
+    int besti = 0;    // best index
+    float best = -1;  // farthest distance
+    // calculating the distance with the latest sampled point
+    float x1 = coords[old];
+    float y1 = coords[old + n];
+    float z1 = coords[old + n + n];
+    for (int k = threadIdx.x; k < n; k += blockDim.x) {
+      // fetch distance at block n, thread k
+      float td = distances[k];
+      float x2, y2, z2;
+      if (k < BufferSize) {
+        x2 = buf[k];
+        y2 = buf[k + BufferSize];
+        z2 = buf[k + BufferSize + BufferSize];
+      } else {
+        x2 = coords[k];
+        y2 = coords[k + n];
+        z2 = coords[k + n + n];
+      }
+      float d =
+          (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
+      float d2 = min(d, td);
+      // update "point-to-set" distance
+      if (d2 != td)
+        distances[k] = d2;
+      // update the farthest distance at sample step j
+      if (d2 > best) {
+        best = d2;
+        besti = k;
+      }
+    }
+
+    dists[threadIdx.x] = best;
+    dists_i[threadIdx.x] = besti;
+    for (int u = 0; (1 << u) < blockDim.x; u++) {
+      __syncthreads();
+      if (threadIdx.x < (blockDim.x >> (u + 1))) {
+        int i1 = (threadIdx.x * 2) << u;
+        int i2 = (threadIdx.x * 2 + 1) << u;
+        if (dists[i1] < dists[i2]) {
+          dists[i1] = dists[i2];
+          dists_i[i1] = dists_i[i2];
+        }
+      }
+    }
+    __syncthreads();
+
+    // finish sample step j; old is the sampled index
+    old = dists_i[0];
+    if (threadIdx.x == 0)
+      indices[j] = old;
+  }
+}
+
+void furthest_point_sampling(int b, int n, int m, const float *coords,
+                             float *distances, int *indices) {
+  furthest_point_sampling_kernel<<<b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(b, n, m, coords, distances,
+                                                                                                     indices);
+  CUDA_CHECK_ERRORS();
+}
diff --git a/modules/functional/src/sampling/sampling.cuh b/modules/functional/src/sampling/sampling.cuh
new file mode 100644
index 0000000..e68358f
--- /dev/null
+++ b/modules/functional/src/sampling/sampling.cuh
@@ -0,0 +1,11 @@
+#ifndef _SAMPLING_CUH
+#define _SAMPLING_CUH
+
+void gather_features(int b, int c, int n, int m, const float *features,
+                     const int *indices, float *out);
+void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
+                          const int *indices, float *grad_x);
+void furthest_point_sampling(int b, int n, int m, const float *coords,
+                             float *distances, int *indices);
+
+#endif
diff --git a/modules/functional/src/sampling/sampling.hpp b/modules/functional/src/sampling/sampling.hpp
new file mode 100644
index 0000000..db2a5c8
--- /dev/null
+++ b/modules/functional/src/sampling/sampling.hpp
@@ -0,0 +1,12 @@
+#ifndef _SAMPLING_HPP
+#define _SAMPLING_HPP
+
+#include <torch/extension.h>
+
+at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices);
+at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
+                                    const int n);
+at::Tensor furthest_point_sampling_forward(at::Tensor coords,
+                                           const int num_samples);
+
+#endif
diff --git a/modules/functional/src/utils.hpp b/modules/functional/src/utils.hpp
new file mode 100644
index 0000000..f4f21a0
--- /dev/null
+++ b/modules/functional/src/utils.hpp
@@ -0,0 +1,20 @@
+#ifndef _UTILS_HPP
+#define _UTILS_HPP
+
+#include <torch/extension.h>
+#include <vector>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+
+#define CHECK_IS_INT(x) \
+  TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
+              #x " must be an int tensor")
+
+#define CHECK_IS_FLOAT(x) \
+  TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
+              #x " must be a float tensor")
+
+#endif
diff --git a/modules/functional/src/voxelization/vox.cpp b/modules/functional/src/voxelization/vox.cpp
new file mode 100644
index 0000000..6a84594
--- /dev/null
+++ b/modules/functional/src/voxelization/vox.cpp
@@ -0,0 +1,76 @@
+#include "vox.hpp"
+#include "vox.cuh"
+
+#include "../utils.hpp"
+
+/*
+  Function: average pool voxelization (forward)
+  Args:
+    features: features, FloatTensor[b, c, n]
+    coords  : coords of each point, IntTensor[b, 3, n]
+    resolution : voxel resolution
+  Return:
+    out : outputs, FloatTensor[b, c, s], s = r ** 3
+    ind : voxel index of each point, IntTensor[b, n]
+    cnt : #points in each voxel index, IntTensor[b, s]
+*/
+std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
+                                             const at::Tensor coords,
+                                             const int resolution) {
+  CHECK_CUDA(features);
+  CHECK_CUDA(coords);
+  CHECK_CONTIGUOUS(features);
+  CHECK_CONTIGUOUS(coords);
+  CHECK_IS_FLOAT(features);
+  CHECK_IS_INT(coords);
+
+  int b = features.size(0);
+  int c = features.size(1);
+  int n = features.size(2);
+  int r = resolution;
+  int r2 = r * r;
+  int r3 = r2 * r;
+  at::Tensor ind = torch::zeros(
+      {b, n}, at::device(features.device()).dtype(at::ScalarType::Int));
+  at::Tensor out = torch::zeros(
+      {b, c, r3}, at::device(features.device()).dtype(at::ScalarType::Float));
+  at::Tensor cnt = torch::zeros(
+      {b, r3}, at::device(features.device()).dtype(at::ScalarType::Int));
+  avg_voxelize(b, c, n, r, r2, r3, coords.data_ptr<int>(),
+               features.data_ptr<float>(), ind.data_ptr<int>(),
+               cnt.data_ptr<int>(), out.data_ptr<float>());
+  return {out, ind, cnt};
+}
+
+/*
+  Function: average pool voxelization (backward)
+  Args:
+    grad_y : grad outputs, FloatTensor[b, c, s]
+    indices: voxel index of each point, IntTensor[b, n]
+    cnt    : #points in each voxel index, IntTensor[b, s]
+  Return:
+    grad_x : grad inputs, FloatTensor[b, c, n]
+*/
+at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
+                                 const at::Tensor indices,
+                                 const at::Tensor cnt) {
+  CHECK_CUDA(grad_y);
+  CHECK_CUDA(indices);
+  CHECK_CUDA(cnt);
+  CHECK_CONTIGUOUS(grad_y);
+  CHECK_CONTIGUOUS(indices);
+  CHECK_CONTIGUOUS(cnt);
+  CHECK_IS_FLOAT(grad_y);
+  CHECK_IS_INT(indices);
+  CHECK_IS_INT(cnt);
+
+  int b = grad_y.size(0);
+  int c = grad_y.size(1);
+  int s = grad_y.size(2);
+  int n = indices.size(1);
+  at::Tensor grad_x = torch::zeros(
+      {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+  avg_voxelize_grad(b, c, n, s, indices.data_ptr<int>(), cnt.data_ptr<int>(),
+                    grad_y.data_ptr<float>(), grad_x.data_ptr<float>());
+  return grad_x;
+}
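`avg_voxelize_forward` is a scatter-mean into an r**3 grid: each point's features are accumulated into its voxel and divided by the voxel's point count. A compact PyTorch reference (an illustrative sketch; the real op also returns the per-point voxel index and the counts, which the backward pass reuses):

```python
import torch

def avg_voxelize_reference(features, coords, r):
    # features: [B, C, N] float, coords: [B, 3, N] integer voxel coords in [0, r)
    b, c, n = features.shape
    idx = (coords[:, 0] * r * r + coords[:, 1] * r + coords[:, 2]).long()  # [B, N]
    out = torch.zeros(b, c, r ** 3, device=features.device)
    cnt = torch.zeros(b, r ** 3, device=features.device)
    out.scatter_add_(2, idx.unsqueeze(1).expand(-1, c, -1), features)
    cnt.scatter_add_(1, idx, torch.ones_like(idx, dtype=features.dtype))
    return out / cnt.clamp(min=1).unsqueeze(1)   # mean over points per voxel
```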
diff --git a/modules/functional/src/voxelization/vox.cu b/modules/functional/src/voxelization/vox.cu
new file mode 100644
index 0000000..1c1a2c9
--- /dev/null
+++ b/modules/functional/src/voxelization/vox.cu
@@ -0,0 +1,126 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+  Function: get how many points in each voxel grid
+  Args:
+    b   : batch size
+    n   : number of points
+    r   : voxel resolution
+    r2  : = r * r
+    r3  : s, voxel cube size = r ** 3
+    coords : coords of each point, IntTensor[b, 3, n]
+    ind : voxel index of each point, IntTensor[b, n]
+    cnt : #points in each voxel index, IntTensor[b, s]
+*/
+__global__ void grid_stats_kernel(int b, int n, int r, int r2, int r3,
+                                  const int *__restrict__ coords,
+                                  int *__restrict__ ind, int *cnt) {
+  int batch_index = blockIdx.x;
+  int stride = blockDim.x;
+  int index = threadIdx.x;
+  coords += batch_index * n * 3;
+  ind += batch_index * n;
+  cnt += batch_index * r3;
+
+  for (int i = index; i < n; i += stride) {
+    // if (ind[i] == -1)
+    //   continue;
+    ind[i] = coords[i] * r2 + coords[i + n] * r + coords[i + n + n];
+    atomicAdd(cnt + ind[i], 1);
+  }
+}
+
+/*
+  Function: average pool voxelization (forward)
+  Args:
+    b   : batch size
+    c   : #channels
+    n   : number of points
+    s   : voxel cube size = voxel resolution ** 3
+    ind : voxel index of each point, IntTensor[b, n]
+    cnt : #points in each voxel index, IntTensor[b, s]
+    feat: features, FloatTensor[b, c, n]
+    out : outputs, FloatTensor[b, c, s]
+*/
+__global__ void avg_voxelize_kernel(int b, int c, int n, int s,
+                                    const int *__restrict__ ind,
+                                    const int *__restrict__ cnt,
+                                    const float *__restrict__ feat,
+                                    float *__restrict__ out) {
+  int batch_index = blockIdx.x;
+  int stride = blockDim.x;
+  int index = threadIdx.x;
+  ind += batch_index * n;
+  feat += batch_index * c * n;
+  out += batch_index * c * s;
+  cnt += batch_index * s;
+  for (int i = index; i < n; i += stride) {
+    int pos = ind[i];
+    // if (pos == -1)
+    //   continue;
+    int cur_cnt = cnt[pos];
+    if (cur_cnt > 0) {
+      float div_cur_cnt = 1.0 / static_cast<float>(cur_cnt);
+      for (int j = 0; j < c; j++) {
+        atomicAdd(out + j * s + pos, feat[j * n + i] * div_cur_cnt);
+      }
+    }
+  }
+}
+
+/*
+  Function: average pool voxelization (backward)
+  Args:
+    b   : batch size
+    c   : #channels
+    n   : number of points
+    r3  : voxel cube size = voxel resolution ** 3
+    ind : voxel index of each point, IntTensor[b, n]
+    cnt : #points in each voxel index, IntTensor[b, s]
+    grad_y : grad outputs, FloatTensor[b, c, s]
+    grad_x : grad inputs, FloatTensor[b, c, n]
+*/
+__global__ void avg_voxelize_grad_kernel(int b, int c, int n, int r3,
+                                         const int *__restrict__ ind,
+                                         const int *__restrict__ cnt,
+                                         const float *__restrict__ grad_y,
+                                         float *__restrict__ grad_x) {
+  int batch_index = blockIdx.x;
+  int stride = blockDim.x;
+  int index = threadIdx.x;
+  ind += batch_index * n;
+  grad_x += batch_index * c * n;
+  grad_y += batch_index * c * r3;
+  cnt += batch_index * r3;
+  for (int i = index; i < n; i += stride) {
+    int pos = ind[i];
+    // if (pos == -1)
+    //   continue;
+    int cur_cnt = cnt[pos];
+    if (cur_cnt > 0) {
+      float div_cur_cnt = 1.0 / static_cast<float>(cur_cnt);
+      for (int j = 0; j < c; j++) {
+        atomicAdd(grad_x + j * n + i, grad_y[j * r3 + pos] * div_cur_cnt);
+      }
+    }
+  }
+}
+
+void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
+                  const float *feat, int *ind, int *cnt, float *out) {
+  grid_stats_kernel<<<b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(b, n, r, r2, r3, coords, ind,
+                                                                                        cnt);
+  avg_voxelize_kernel<<<b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(b, c, n, r3, ind, cnt,
+                                                                                          feat, out);
+  CUDA_CHECK_ERRORS();
+}
+
+void avg_voxelize_grad(int b, int c, int n, int s, const int *ind,
+                       const int *cnt, const float *grad_y, float *grad_x) {
+  avg_voxelize_grad_kernel<<<b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(b, c, n, s, ind, cnt,
+                                                                                               grad_y, grad_x);
+  CUDA_CHECK_ERRORS();
+}
diff --git a/modules/functional/src/voxelization/vox.cuh b/modules/functional/src/voxelization/vox.cuh
new file mode 100644
index 0000000..9adb0fd
--- /dev/null
+++ b/modules/functional/src/voxelization/vox.cuh
@@ -0,0 +1,10 @@
+#ifndef _VOX_CUH
+#define _VOX_CUH
+
+// CUDA function declarations
+void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
+                  const float *feat, int *ind, int *cnt, float *out);
+void avg_voxelize_grad(int b, int c, int n, int s, const int *idx,
+                       const int *cnt, const float *grad_y, float *grad_x);
+
+#endif
diff --git a/modules/functional/src/voxelization/vox.hpp b/modules/functional/src/voxelization/vox.hpp
new file mode 100644
index 0000000..6e62bc3
--- /dev/null
+++ b/modules/functional/src/voxelization/vox.hpp
@@ -0,0 +1,15 @@
+#ifndef _VOX_HPP
+#define _VOX_HPP
+
+#include <torch/extension.h>
+#include <vector>
+
+std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
+                                             const at::Tensor coords,
+                                             const int resolution);
+
+at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
+                                 const at::Tensor indices,
+                                 const at::Tensor cnt);
+
+#endif
diff --git a/modules/functional/voxelization.py
b/modules/functional/voxelization.py new file mode 100644 index 0000000..2452c68 --- /dev/null +++ b/modules/functional/voxelization.py @@ -0,0 +1,40 @@ +from torch.autograd import Function + +from modules.functional.backend import _backend + +__all__ = ['avg_voxelize'] + + +class AvgVoxelization(Function): + @staticmethod + def forward(ctx, features, coords, resolution): + """ + :param ctx: + :param features: Features of the point cloud, FloatTensor[B, C, N] + :param coords: Voxelized Coordinates of each point, IntTensor[B, 3, N] + :param resolution: Voxel resolution + :return: + Voxelized Features, FloatTensor[B, C, R, R, R] + """ + features = features.contiguous() + coords = coords.int().contiguous() + b, c, _ = features.shape + out, indices, counts = _backend.avg_voxelize_forward(features, coords, resolution) + ctx.save_for_backward(indices, counts) + return out.view(b, c, resolution, resolution, resolution) + + @staticmethod + def backward(ctx, grad_output): + """ + :param ctx: + :param grad_output: gradient of output, FloatTensor[B, C, R, R, R] + :return: + gradient of inputs, FloatTensor[B, C, N] + """ + b, c = grad_output.shape[:2] + indices, counts = ctx.saved_tensors + grad_features = _backend.avg_voxelize_backward(grad_output.contiguous().view(b, c, -1), indices, counts) + return grad_features, None, None + + +avg_voxelize = AvgVoxelization.apply diff --git a/modules/loss.py b/modules/loss.py new file mode 100644 index 0000000..173052d --- /dev/null +++ b/modules/loss.py @@ -0,0 +1,10 @@ +import torch.nn as nn + +import modules.functional as F + +__all__ = ['KLLoss'] + + +class KLLoss(nn.Module): + def forward(self, x, y): + return F.kl_loss(x, y) diff --git a/modules/pointnet.py b/modules/pointnet.py new file mode 100644 index 0000000..7925acf --- /dev/null +++ b/modules/pointnet.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn + +import modules.functional as F +from modules.ball_query import BallQuery +from modules.shared_mlp import SharedMLP + +__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule'] + + +class PointNetAModule(nn.Module): + def __init__(self, in_channels, out_channels, include_coordinates=True): + super().__init__() + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] + + mlps = [] + total_out_channels = 0 + for _out_channels in out_channels: + mlps.append( + SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0), + out_channels=_out_channels, dim=1) + ) + total_out_channels += _out_channels[-1] + + self.include_coordinates = include_coordinates + self.out_channels = total_out_channels + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + features, coords = inputs + if self.include_coordinates: + features = torch.cat([features, coords], dim=1) + coords = torch.zeros((coords.size(0), 3, 1), device=coords.device) + if len(self.mlps) > 1: + features_list = [] + for mlp in self.mlps: + features_list.append(mlp(features).max(dim=-1, keepdim=True).values) + return torch.cat(features_list, dim=1), coords + else: + return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords + + def extra_repr(self): + return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}' + + +class PointNetSAModule(nn.Module): + def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True): + super().__init__() + if not isinstance(radius, 
(list, tuple)):
+            radius = [radius]
+        if not isinstance(num_neighbors, (list, tuple)):
+            num_neighbors = [num_neighbors] * len(radius)
+        assert len(radius) == len(num_neighbors)
+        if not isinstance(out_channels, (list, tuple)):
+            out_channels = [[out_channels]] * len(radius)
+        elif not isinstance(out_channels[0], (list, tuple)):
+            out_channels = [out_channels] * len(radius)
+        assert len(radius) == len(out_channels)
+
+        groupers, mlps = [], []
+        total_out_channels = 0
+        for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
+            groupers.append(
+                BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
+            )
+            mlps.append(
+                SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
+                          out_channels=_out_channels, dim=2)
+            )
+            total_out_channels += _out_channels[-1]
+
+        self.num_centers = num_centers
+        self.out_channels = total_out_channels
+        self.groupers = nn.ModuleList(groupers)
+        self.mlps = nn.ModuleList(mlps)
+
+    def forward(self, inputs):
+        features, coords, temb = inputs
+        centers_coords = F.furthest_point_sample(coords, self.num_centers)
+        features_list = []
+        for grouper, mlp in zip(self.groupers, self.mlps):
+            features, temb = mlp(grouper(coords, centers_coords, temb, features))
+            features_list.append(features.max(dim=-1).values)
+        if len(features_list) > 1:
+            return torch.cat(features_list, dim=1), centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
+        else:
+            return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
+
+    def extra_repr(self):
+        return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
+
+
+class PointNetFPModule(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)
+
+    def forward(self, inputs):
+        if len(inputs) == 4:
+            points_coords, centers_coords, centers_features, temb = inputs
+            points_features = None
+        else:
+            points_coords, centers_coords, centers_features, points_features, temb = inputs
+        interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
+        interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
+        if points_features is not None:
+            interpolated_features = torch.cat(
+                [interpolated_features, points_features], dim=1
+            )
+        return self.mlp(interpolated_features), points_coords, interpolated_temb
diff --git a/modules/pvconv.py b/modules/pvconv.py
new file mode 100644
index 0000000..bcacfb0
--- /dev/null
+++ b/modules/pvconv.py
@@ -0,0 +1,132 @@
+import torch.nn as nn
+import torch
+import modules.functional as F
+from modules.voxelization import Voxelization
+from modules.shared_mlp import SharedMLP
+from modules.se import SE3d
+
+__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU']
+
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, in_ch, num_groups, D=3):
+        super(Attention, self).__init__()
+        assert in_ch % num_groups == 0
+        if D == 3:
+            self.q = nn.Conv3d(in_ch, in_ch, 1)
+            self.k = nn.Conv3d(in_ch, in_ch, 1)
+            self.v = nn.Conv3d(in_ch, in_ch, 1)
+
+            self.out = nn.Conv3d(in_ch, in_ch, 1)
+        elif D == 1:
+            self.q = nn.Conv1d(in_ch, in_ch, 1)
+            self.k = nn.Conv1d(in_ch, in_ch, 1)
+            self.v = nn.Conv1d(in_ch, in_ch, 1)
+
+            self.out = nn.Conv1d(in_ch, in_ch, 1)
+
+        self.norm = nn.GroupNorm(num_groups, in_ch)
+        self.nonlin = Swish()
+
+        self.sm =
nn.Softmax(-1) + + + def forward(self, x): + B, C = x.shape[:2] + h = x + + + + + q = self.q(h).reshape(B,C,-1) + k = self.k(h).reshape(B,C,-1) + v = self.v(h).reshape(B,C,-1) + + qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5)) + + w = self.sm(qk) + + h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:]) + + h = self.out(h) + + x = h + x + + x = self.nonlin(self.norm(x)) + + return x + +class PVConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, + dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.resolution = resolution + + self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps) + voxel_layers = [ + nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), + nn.GroupNorm(num_groups=8, num_channels=out_channels), + Swish() + ] + voxel_layers += [nn.Dropout(dropout)] if dropout is not None else [] + voxel_layers += [ + nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), + nn.GroupNorm(num_groups=8, num_channels=out_channels), + Attention(out_channels, 8) if attention else Swish() + ] + if with_se: + voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu)) + self.voxel_layers = nn.Sequential(*voxel_layers) + self.point_features = SharedMLP(in_channels, out_channels) + + def forward(self, inputs): + features, coords, temb = inputs + voxel_features, voxel_coords = self.voxelization(features, coords) + voxel_features = self.voxel_layers(voxel_features) + voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training) + fused_features = voxel_features + self.point_features(features) + return fused_features, coords, temb + + + +class PVConvReLU(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2, + dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.resolution = resolution + + self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps) + voxel_layers = [ + nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), + nn.BatchNorm3d(out_channels), + nn.LeakyReLU(leak, True) + ] + voxel_layers += [nn.Dropout(dropout)] if dropout is not None else [] + voxel_layers += [ + nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), + nn.BatchNorm3d(out_channels), + Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True) + ] + if with_se: + voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu)) + self.voxel_layers = nn.Sequential(*voxel_layers) + self.point_features = SharedMLP(in_channels, out_channels) + + def forward(self, inputs): + features, coords, temb = inputs + voxel_features, voxel_coords = self.voxelization(features, coords) + voxel_features = self.voxel_layers(voxel_features) + voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training) + fused_features = voxel_features + self.point_features(features) + return fused_features, coords, temb diff --git a/modules/se.py b/modules/se.py new file mode 100644 index 0000000..c34eef7 --- /dev/null +++ b/modules/se.py @@ -0,0 +1,19 @@ +import torch.nn as nn 
+import torch +__all__ = ['SE3d'] + +class Swish(nn.Module): + def forward(self,x): + return x * torch.sigmoid(x) +class SE3d(nn.Module): + def __init__(self, channel, reduction=8, use_relu=False): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(True) if use_relu else Swish() , + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, inputs): + return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1) diff --git a/modules/shared_mlp.py b/modules/shared_mlp.py new file mode 100644 index 0000000..1fcc35e --- /dev/null +++ b/modules/shared_mlp.py @@ -0,0 +1,38 @@ +import torch.nn as nn +import torch + +__all__ = ['SharedMLP'] + + +class Swish(nn.Module): + def forward(self,x): + return x * torch.sigmoid(x) + +class SharedMLP(nn.Module): + def __init__(self, in_channels, out_channels, dim=1): + super().__init__() + if dim == 1: + conv = nn.Conv1d + bn = nn.GroupNorm + elif dim == 2: + conv = nn.Conv2d + bn = nn.GroupNorm + else: + raise ValueError + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + layers = [] + for oc in out_channels: + layers.extend([ + conv(in_channels, oc, 1), + bn(8, oc), + Swish(), + ]) + in_channels = oc + self.layers = nn.Sequential(*layers) + + def forward(self, inputs): + if isinstance(inputs, (list, tuple)): + return (self.layers(inputs[0]), *inputs[1:]) + else: + return self.layers(inputs) diff --git a/modules/voxelization.py b/modules/voxelization.py new file mode 100644 index 0000000..7efc614 --- /dev/null +++ b/modules/voxelization.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + +import modules.functional as F + +__all__ = ['Voxelization'] + + +class Voxelization(nn.Module): + def __init__(self, resolution, normalize=True, eps=0): + super().__init__() + self.r = int(resolution) + self.normalize = normalize + self.eps = eps + + def forward(self, features, coords): + coords = coords.detach() + norm_coords = coords - coords.mean(2, keepdim=True) + if self.normalize: + norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5 + else: + norm_coords = (norm_coords + 1) / 2.0 + norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1) + vox_coords = torch.round(norm_coords).to(torch.int32) + return F.avg_voxelize(features, vox_coords, self.r), norm_coords + + def extra_repr(self): + return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '') diff --git a/requirement_voxel.txt b/requirement_voxel.txt new file mode 100644 index 0000000..9814675 --- /dev/null +++ b/requirement_voxel.txt @@ -0,0 +1,26 @@ +conda: +python==3.6 +torch==1.4.0 +torchvision==0.5.0 +cudatoolkit==10.1 +kaolin==0.1.0 +pytorch3d==0.2.5 +lutorpy=1.3.7 +xmltodict=0.12.0 +numba=0.51.2 +pycuda=2019.1.2 +matplotlib + +pip: +torch-scatter==2.0.4 +torch-sparse==0.6.1 +torch-cluster==1.5.4 +torch-spline-conv==1.2.0 +descartes==1.1.0 +fire==0.3.1 +jupyter==1.0.0 +opencv_python==4.3.0 +Shapely==1.7.0 +Pillow==6.2.1 +torch_geometric==1.6.0 +open3d diff --git a/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py new file mode 100644 index 0000000..cd76baf --- /dev/null +++ b/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py @@ -0,0 +1,825 @@ +import 
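`Voxelization` (in modules/voxelization.py below) maps each cloud into the unit cube and quantizes it to an `r`-cube grid before the custom `avg_voxelize` scatter kernel. The coordinate mapping itself is plain PyTorch and can be checked in isolation; a sketch assuming `normalize=True`:

```python
import torch

r, eps = 32, 0.0
coords = torch.randn(2, 3, 1024)                       # (batch, xyz, points)

norm_coords = coords - coords.mean(2, keepdim=True)    # center each cloud
# Scale so the farthest point sits at radius 0.5, then shift into [0, 1].
norm_coords = norm_coords / (
    norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + eps
) + 0.5
norm_coords = torch.clamp(norm_coords * r, 0, r - 1)
vox_coords = torch.round(norm_coords).to(torch.int32)

assert 0 <= int(vox_coords.min()) and int(vox_coords.max()) <= r - 1
```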
torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from model.pvcnn_completion import PVCNN2Base +import torch.distributed as dist +from datasets.partnet import GANdatasetPartNet +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +''' +some utils +''' +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. + """ + axis = np.asarray(axis) + axis = axis / np.sqrt(np.dot(axis, axis)) + a = np.cos(theta / 2.0) + b, c, d = -axis * np.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + +def rotate(vertices, faces): + ''' + vertices: [numpoints, 3] + ''' + M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() + N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() + K = rotation_matrix([0, 0, 1], np.pi).transpose() + + v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] + return v, f + +def norm(v, f): + v = (v - v.min())/(v.max() - v.min()) - 0.5 + + return v, f + +def getGradNorm(net): + pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) + gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) + return pNorm, gradNorm + + +def weights_init(m): + """ + xavier initialization + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1 and m.weight is not None: + torch.nn.init.xavier_normal_(m.weight) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_() + m.bias.data.fill_(0) + +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. 
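`rotation_matrix` is the standard Euler-Rodrigues (quaternion-derived) axis-angle rotation. Two properties worth a quick check are that the result is orthonormal with determinant 1 and that the rotation axis itself is fixed; a self-contained check using the function as defined above:

```python
import numpy as np

def rotation_matrix(axis, theta):
    axis = np.asarray(axis, dtype=np.float64)
    axis = axis / np.sqrt(np.dot(axis, axis))
    a = np.cos(theta / 2.0)
    b, c, d = -axis * np.sin(theta / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                     [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                     [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])

R = rotation_matrix([0, 1, 0], np.pi / 2)
assert np.allclose(R @ R.T, np.eye(3))                    # orthonormal
assert np.allclose(R @ np.array([0, 1, 0]), [0, 1, 0])    # axis is fixed
assert np.isclose(np.linalg.det(R), 1.0)                  # proper rotation
```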
- cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
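The constructor precomputes every schedule-derived buffer once in float64. The comment above `posterior_variance` claims the closed form equals a harmonic combination of the two precisions; that identity is easy to verify numerically with the linear schedule `get_betas` produces:

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)        # linear schedule, as in get_betas
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# "equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)";
# t = 0 is skipped to avoid a divide-by-zero (alphas_cumprod_prev[0] == 1).
alt = 1.0 / (1.0 / (1.0 - alphas_cumprod_prev[1:]) + alphas[1:] / betas[1:])
assert np.allclose(posterior_variance[1:], alt)
```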
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = 
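`q_sample` draws x_t in one shot from the closed form q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I) instead of iterating t single steps; `_extract` just reshapes the gathered scalars to (B, 1, 1) for broadcasting. A toy demonstration with the same buffers:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(8, 3, 1024, dtype=torch.float64)     # (B, xyz, N) point clouds
t = torch.randint(0, T, (8,))
noise = torch.randn_like(x0)

coef = alphas_cumprod[t].view(-1, 1, 1)               # one alpha_bar per example
x_t = coef.sqrt() * x0 + (1.0 - coef).sqrt() * noise

# Near the end of the chain the signal is essentially gone:
assert alphas_cumprod[-1] < 1e-4
```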
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0,total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. 
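The completion-specific part of the sampler is how the partial shape is handled: `p_sample_loop` seeds only the tail of the point axis with noise, and every `p_sample` step re-concatenates the untouched first `sv_points` columns before returning, so the observed points are never noised or denoised. A structural sketch with the posterior update replaced by a placeholder and a dummy network standing in for PVCNN2:

```python
import torch

sv_points, total = 1024, 2048
partial_x = torch.randn(4, 3, sv_points)          # observed partial shape (fixed)
free = torch.randn(4, 3, total - sv_points)       # noise for the points to generate

def dummy_denoise(data, t):                       # stands in for the PVCNN2 net
    return torch.zeros_like(data)

for t in reversed(range(50)):                     # shortened chain, illustration only
    model_in = torch.cat([partial_x, free], dim=-1)
    eps = dummy_denoise(model_in, t)[:, :, sv_points:]
    free = free - 0.01 * eps                      # placeholder for the real update

completed = torch.cat([partial_x, free], dim=-1)
assert torch.equal(completed[:, :, :sv_points], partial_x)   # conditioning intact
```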
seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, 
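With `loss_type == 'mse'` the network regresses the injected noise (epsilon-prediction) rather than x_0, and only the free points are supervised: the clean partial points are concatenated in front of the noised tail for context, then sliced off both the model output and the loss. The core of `p_losses`, isolated with a stand-in denoiser:

```python
import torch

B, sv, total = 4, 1024, 2048
data_start = torch.randn(B, 3, total)
t = torch.randint(0, 1000, (B,))
noise = torch.randn(B, 3, total - sv)
data_t = noise                          # stands in for q_sample(x0_tail, t, noise)

def denoise_fn(x, t):                   # stands in for the PVCNN2 network
    return torch.randn_like(x)

# Context = clean partial points + noised tail; supervise only the tail.
eps_recon = denoise_fn(torch.cat([data_start[:, :, :sv], data_t], dim=-1), t)[:, :, sv:]
losses = ((noise - eps_recon) ** 2).mean(dim=[1, 2])    # one scalar per example
assert losses.shape == (B,)
```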
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + +def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category): + + train_ds = GANdatasetPartNet('train', data_root, category, npoints) + return train_ds + + +def get_dataloader(opt, train_dataset, test_dataset=None): + + if opt.distribution_type == 'multi': + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=opt.world_size, + rank=opt.rank + ) + if test_dataset is not None: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset, + num_replicas=opt.world_size, + rank=opt.rank + ) + else: + test_sampler = None + else: + train_sampler = None + test_sampler = None + + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler, + shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True) + + if test_dataset is not None: + test_dataloader = 
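Besides the linear ramp, `get_betas` implements three 'warm' schedules that differ only in what fraction of the chain ramps up before beta is held at `b_end`, so the three branches collapse into one parameterized rule. An equivalent compact sketch (a refactor, not the original code):

```python
import numpy as np

def get_betas(schedule_type, b_start, b_end, time_num):
    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num)
    if schedule_type.startswith('warm'):                 # e.g. 'warm0.1'
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * float(schedule_type[4:]))
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
        return betas
    raise NotImplementedError(schedule_type)

warm = get_betas('warm0.1', 1e-4, 0.02, 1000)
assert warm[:100].min() == 1e-4 and (warm[100:] == 0.02).all()
```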
torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + else: + test_dataloader = None + + return train_dataloader, test_dataloader, train_sampler, test_sampler + + +def train(gpu, opt, output_dir, noises_init): + + set_seed(opt) + logger = setup_logging(output_dir) + if opt.distribution_type == 'multi': + should_diag = gpu==0 + else: + should_diag = True + if should_diag: + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + if opt.distribution_type == 'multi': + if opt.dist_url == "env://" and opt.rank == -1: + opt.rank = int(os.environ["RANK"]) + + base_rank = opt.rank * opt.ngpus_per_node + opt.rank = base_rank + gpu + dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, + world_size=opt.world_size, rank=opt.rank) + + opt.bs = int(opt.bs / opt.ngpus_per_node) + opt.workers = 0 + + opt.saveIter = int(opt.saveIter / opt.ngpus_per_node) + opt.diagIter = int(opt.diagIter / opt.ngpus_per_node) + opt.vizIter = int(opt.vizIter / opt.ngpus_per_node) + + + ''' data ''' + train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes) + dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) + + + ''' + create networks + ''' + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.distribution_type == 'multi': # Multiple processes, single GPU per process + def _transform_(m): + return nn.parallel.DistributedDataParallel( + m, device_ids=[gpu], output_device=gpu) + + torch.cuda.set_device(gpu) + netE.cuda(gpu) + netE.multi_gpu_wrapper(_transform_) + + + elif opt.distribution_type == 'single': + def _transform_(m): + return nn.parallel.DataParallel(m) + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + elif gpu is not None: + torch.cuda.set_device(gpu) + netE = netE.cuda(gpu) + else: + raise ValueError('distribution_type = multi | single | None') + + if should_diag: + logger.info(opt) + + optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999)) + + lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma) + + if opt.netE != '': + ckpt = torch.load(opt.netE) + netE.load_state_dict(ckpt['model_state']) + optimizer.load_state_dict(ckpt['optimizer_state']) + + if opt.netE != '': + start_epoch = torch.load(opt.netE)['epoch'] + 1 + else: + start_epoch = 0 + + + for epoch in range(start_epoch, opt.niter): + + if opt.distribution_type == 'multi': + train_sampler.set_epoch(epoch) + + lr_scheduler.step(epoch) + + for i, data in enumerate(dataloader): + + x = data['real'] + sv_x = data['raw'] + + sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1) + noises_batch = noises_init[data['idx']] + + ''' + train diffusion + ''' + + if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None): + sv_x = sv_x.cuda(gpu) + noises_batch = noises_batch.cuda(gpu) + elif opt.distribution_type == 'single': + sv_x = sv_x.cuda() + noises_batch = noises_batch.cuda() + + loss = netE.get_loss_iter(sv_x, noises_batch).mean() + + optimizer.zero_grad() + loss.backward() + netpNorm, netgradNorm = getGradNorm(netE) + if opt.grad_clip is not None: + torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip) + + optimizer.step() + + + if i % opt.print_freq == 0 and should_diag: + + logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: 
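Two details in this span deserve a flag. First, the `test_dataset is not None` branch of `get_dataloader` builds its loader from `train_dataset`; presumably `test_dataset` was intended (harmless here only because every call site passes `test_dataset=None`). Second, `getGradNorm` assumes every parameter has a populated `.grad`, which raises on parameters untouched by the current graph. Hedged corrected sketches of both, under hypothetical names `get_test_loader` / `get_grad_norm`:

```python
import torch
import torch.nn as nn

def get_test_loader(test_dataset, bs, workers, sampler=None):
    # Uses the dataset it is given, unlike the snippet above.
    return torch.utils.data.DataLoader(test_dataset, batch_size=bs, sampler=sampler,
                                       shuffle=False, num_workers=workers,
                                       drop_last=False)

def get_grad_norm(net):
    p_norm = torch.sqrt(sum(p.pow(2).sum() for p in net.parameters()))
    g_norm = torch.sqrt(sum(p.grad.pow(2).sum()
                            for p in net.parameters() if p.grad is not None))
    return p_norm, g_norm

net = nn.Linear(3, 3)
net(torch.randn(2, 3)).sum().backward()
print(get_grad_norm(net))
```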
{:>10.4f}, ' + 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' + .format( + epoch, opt.niter, i, len(dataloader),loss.item(), + netpNorm, netgradNorm, + )) + + if (epoch + 1) % opt.diagIter == 0 and should_diag: + + logger.info('Diagnosis:') + + x_range = [x.min().item(), x.max().item()] + kl_stats = netE.all_kl(sv_x) + logger.info(' [{:>3d}/{:>3d}] ' + 'x_range: [{:>10.4f}, {:>10.4f}], ' + 'total_bpd_b: {:>10.4f}, ' + 'terms_bpd: {:>10.4f}, ' + 'prior_bpd_b: {:>10.4f} ' + 'mse_bt: {:>10.4f} ' + .format( + epoch, opt.niter, + *x_range, + kl_stats['total_bpd_b'].item(), + kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item() + )) + + + + if (epoch + 1) % opt.vizIter == 0 and should_diag: + logger.info('Generation: eval') + + netE.eval() + m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0) + + with torch.no_grad(): + + x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu() + + + gen_stats = [x_gen_eval.mean(), x_gen_eval.std()] + gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()] + + logger.info(' [{:>3d}/{:>3d}] ' + 'eval_gen_range: [{:>10.4f}, {:>10.4f}] ' + 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ' + .format( + epoch, opt.niter, + *gen_eval_range, *gen_stats, + )) + + export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch), + (x_gen_eval*s+m).transpose(1, 2).numpy()*3) + + export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch), + (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3) + + export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch), + (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3) + + + netE.train() + + + + + + + + if (epoch + 1) % opt.saveIter == 0: + + if should_diag: + + + save_dict = { + 'epoch': epoch, + 'model_state': netE.state_dict(), + 'optimizer_state': optimizer.state_dict() + } + + torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch)) + + + if opt.distribution_type == 'multi': + dist.barrier() + map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu} + netE.load_state_dict( + torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state']) + + dist.destroy_process_group() + +def main(): + opt = parse_args() + + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + + ''' workaround ''' + + train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes) + noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints) + + if opt.dist_url == "env://" and opt.world_size == -1: + opt.world_size = int(os.environ["WORLD_SIZE"]) + + if opt.distribution_type == 'multi': + opt.ngpus_per_node = torch.cuda.device_count() + opt.world_size = opt.ngpus_per_node * opt.world_size + mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init)) + else: + train(opt.gpu, opt, output_dir, noises_init) + + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/data_v0', help='input batch size') + parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc', + help='input batch size') + parser.add_argument('--pc_dataroot', 
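After each save, the multi-GPU branch synchronizes with `dist.barrier()` and reloads the checkpoint on every rank using a `map_location` that remaps tensors serialized from `cuda:0` onto the local device, the standard DDP checkpointing pattern. (Separately, `lr_scheduler.step(epoch)` with an explicit epoch argument is deprecated in newer PyTorch in favor of a bare `step()` once per epoch.) The reload pattern in isolation:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
torch.save({'epoch': 0, 'model_state': model.state_dict()}, '/tmp/epoch_0.pth')

gpu = 0                                          # local rank in the real script
map_location = {'cuda:0': 'cuda:%d' % gpu}       # saved device -> local device
ckpt = torch.load('/tmp/epoch_0.pth', map_location=map_location)
model.load_state_dict(ckpt['model_state'])
start_epoch = ckpt['epoch'] + 1
```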
default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', + help='input batch size') + parser.add_argument('--classes', default='Chair') + + parser.add_argument('--bs', type=int, default=64, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=1024) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002') + parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') + parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM') + parser.add_argument('--grad_clip', type=float, default=None, help='weight decay for EBM') + parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM') + + parser.add_argument('--netE', default='', help="path to netE (to continue training)") + + + '''distributed''' + parser.add_argument('--world_size', default=1, type=int, + help='Number of distributed nodes.') + parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str, + help='url used to set up distributed training') + parser.add_argument('--dist_backend', default='nccl', type=str, + help='distributed backend') + parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None], + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + parser.add_argument('--rank', default=0, type=int, + help='node rank for distributed training') + parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use. 
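Several flags above declare numeric defaults without a `type=`, so a value passed on the command line arrives as a string (`--npoints 4096` would break the `opt.npoints - opt.svpoints` arithmetic in `main`), and the help text on the data-path flags is a copy-paste leftover ("input batch size"). A typed sketch of those flags:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--nc', type=int, default=3, help='point channels (xyz)')
parser.add_argument('--npoints', type=int, default=2048, help='points per shape')
parser.add_argument('--svpoints', type=int, default=1024, help='observed (partial) points')
parser.add_argument('--beta_start', type=float, default=0.0001)
parser.add_argument('--beta_end', type=float, default=0.02)
parser.add_argument('--time_num', type=int, default=1000, help='diffusion steps')
parser.add_argument('--dropout', type=float, default=0.1)

opt = parser.parse_args(['--npoints', '4096'])   # CLI strings now coerced to int
assert opt.npoints - opt.svpoints == 3072
```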
None means using all available GPUs.') + + '''eval''' + parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch') + parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch') + parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch') + parser.add_argument('--print_freq', default=50, type=int,help='unit: iter') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + + opt = parser.parse_args() + + return opt + +if __name__ == '__main__': + main() diff --git a/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py new file mode 100644 index 0000000..f7a39fa --- /dev/null +++ b/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py @@ -0,0 +1,825 @@ +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from model.pvcnn_completion import PVCNN2Base +import torch.distributed as dist +from datasets.partnet import GANdatasetPartNet +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +''' +some utils +''' +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. + """ + axis = np.asarray(axis) + axis = axis / np.sqrt(np.dot(axis, axis)) + a = np.cos(theta / 2.0) + b, c, d = -axis * np.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + +def rotate(vertices, faces): + ''' + vertices: [numpoints, 3] + ''' + M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() + N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() + K = rotation_matrix([0, 0, 1], np.pi).transpose() + + v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] + return v, f + +def norm(v, f): + v = (v - v.min())/(v.max() - v.min()) - 0.5 + + return v, f + +def getGradNorm(net): + pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) + gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) + return pNorm, gradNorm + + +def weights_init(m): + """ + xavier initialization + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1 and m.weight is not None: + torch.nn.init.xavier_normal_(m.weight) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_() + m.bias.data.fill_(0) + +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. 
+ """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
+ """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, 
x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0,total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) 
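`p_sample` silences the noise term on the last step by building a per-example mask that is zero exactly where t == 0, so the final output is the posterior mean itself. The masking in isolation:

```python
import torch

model_mean = torch.zeros(4, 3, 1024)
t = torch.tensor([0, 1, 5, 0])

nonzero_mask = torch.reshape(1 - (t == 0).float(),
                             [4] + [1] * (len(model_mean.shape) - 1))
assert nonzero_mask.flatten().tolist() == [0.0, 1.0, 1.0, 0.0]

noise = torch.randn_like(model_mean)
sample = model_mean + nonzero_mask * 0.1 * noise       # exp(0.5 * logvar) elided
assert torch.equal(sample[0], model_mean[0])           # deterministic at t == 0
```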
+ + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 
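The `/ np.log(2.)` in `_vb_terms_bpd` and `_prior_bpd` converts the per-dimension KL from nats to bits, which is what makes these diagnostics bits-per-dimension. For reference, the closed form `normal_kl` implements, writing each log-variance as ℓ, is

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1, e^{\ell_1}) \,\|\, \mathcal{N}(\mu_2, e^{\ell_2})\big)
  = \tfrac{1}{2}\Big(\ell_2 - \ell_1 + e^{\ell_1 - \ell_2}
  + (\mu_1 - \mu_2)^2 e^{-\ell_2} - 1\Big)
```

and `_prior_bpd` evaluates it against the standard normal, i.e. with the second mean and log-variance set to zero.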
2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + +def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category): + + train_ds = GANdatasetPartNet('train', data_root, category, npoints) + return train_ds + + +def get_dataloader(opt, train_dataset, test_dataset=None): + + if opt.distribution_type == 'multi': + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=opt.world_size, + rank=opt.rank + ) + if 
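A reading note on the `sa_blocks` tuples (an assumption inferred from the `PointNetSAModule` signature in modules/pointnet.py, since `PVCNN2Base` itself lives in model/pvcnn_completion.py, not shown here): each entry appears to pair a voxel-conv config `(out_channels, num_blocks, voxel_resolution)` with a set-abstraction config `(num_centers, radius, num_neighbors, mlp_out_channels)`:

```python
# Hypothetical unpacking of the first sa_blocks entry above.
conv_cfg, sa_cfg = (32, 2, 32), (1024, 0.1, 32, (32, 64))
num_centers, radius, num_neighbors, out_channels = sa_cfg
assert num_centers == 1024 and radius == 0.1 and out_channels[-1] == 64
```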
test_dataset is not None: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset, + num_replicas=opt.world_size, + rank=opt.rank + ) + else: + test_sampler = None + else: + train_sampler = None + test_sampler = None + + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler, + shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True) + + if test_dataset is not None: + test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + else: + test_dataloader = None + + return train_dataloader, test_dataloader, train_sampler, test_sampler + + +def train(gpu, opt, output_dir, noises_init): + + set_seed(opt) + logger = setup_logging(output_dir) + if opt.distribution_type == 'multi': + should_diag = gpu==0 + else: + should_diag = True + if should_diag: + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + if opt.distribution_type == 'multi': + if opt.dist_url == "env://" and opt.rank == -1: + opt.rank = int(os.environ["RANK"]) + + base_rank = opt.rank * opt.ngpus_per_node + opt.rank = base_rank + gpu + dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, + world_size=opt.world_size, rank=opt.rank) + + opt.bs = int(opt.bs / opt.ngpus_per_node) + opt.workers = 0 + + opt.saveIter = int(opt.saveIter / opt.ngpus_per_node) + opt.diagIter = int(opt.diagIter / opt.ngpus_per_node) + opt.vizIter = int(opt.vizIter / opt.ngpus_per_node) + + + ''' data ''' + train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes) + dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) + + + ''' + create networks + ''' + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.distribution_type == 'multi': # Multiple processes, single GPU per process + def _transform_(m): + return nn.parallel.DistributedDataParallel( + m, device_ids=[gpu], output_device=gpu) + + torch.cuda.set_device(gpu) + netE.cuda(gpu) + netE.multi_gpu_wrapper(_transform_) + + + elif opt.distribution_type == 'single': + def _transform_(m): + return nn.parallel.DataParallel(m) + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + elif gpu is not None: + torch.cuda.set_device(gpu) + netE = netE.cuda(gpu) + else: + raise ValueError('distribution_type = multi | single | None') + + if should_diag: + logger.info(opt) + + optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999)) + + lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma) + + if opt.netE != '': + ckpt = torch.load(opt.netE) + netE.load_state_dict(ckpt['model_state']) + optimizer.load_state_dict(ckpt['optimizer_state']) + + if opt.netE != '': + start_epoch = torch.load(opt.netE)['epoch'] + 1 + else: + start_epoch = 0 + + + for epoch in range(start_epoch, opt.niter): + + if opt.distribution_type == 'multi': + train_sampler.set_epoch(epoch) + + lr_scheduler.step(epoch) + + for i, data in enumerate(dataloader): + + x = data['real'] + sv_x = data['raw'] + + sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1) + noises_batch = noises_init[data['idx']] + + ''' + train diffusion + ''' + + if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None): + sv_x = sv_x.cuda(gpu) + noises_batch = 
noises_batch.cuda(gpu) + elif opt.distribution_type == 'single': + sv_x = sv_x.cuda() + noises_batch = noises_batch.cuda() + + loss = netE.get_loss_iter(sv_x, noises_batch).mean() + + optimizer.zero_grad() + loss.backward() + netpNorm, netgradNorm = getGradNorm(netE) + if opt.grad_clip is not None: + torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip) + + optimizer.step() + + + if i % opt.print_freq == 0 and should_diag: + + logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, ' + 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' + .format( + epoch, opt.niter, i, len(dataloader),loss.item(), + netpNorm, netgradNorm, + )) + + if (epoch + 1) % opt.diagIter == 0 and should_diag: + + logger.info('Diagnosis:') + + x_range = [x.min().item(), x.max().item()] + kl_stats = netE.all_kl(sv_x) + logger.info(' [{:>3d}/{:>3d}] ' + 'x_range: [{:>10.4f}, {:>10.4f}], ' + 'total_bpd_b: {:>10.4f}, ' + 'terms_bpd: {:>10.4f}, ' + 'prior_bpd_b: {:>10.4f} ' + 'mse_bt: {:>10.4f} ' + .format( + epoch, opt.niter, + *x_range, + kl_stats['total_bpd_b'].item(), + kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item() + )) + + + + if (epoch + 1) % opt.vizIter == 0 and should_diag: + logger.info('Generation: eval') + + netE.eval() + m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0) + + with torch.no_grad(): + + x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu() + + + gen_stats = [x_gen_eval.mean(), x_gen_eval.std()] + gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()] + + logger.info(' [{:>3d}/{:>3d}] ' + 'eval_gen_range: [{:>10.4f}, {:>10.4f}] ' + 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ' + .format( + epoch, opt.niter, + *gen_eval_range, *gen_stats, + )) + + export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch), + (x_gen_eval*s+m).transpose(1, 2).numpy()*3) + + export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch), + (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3) + + export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch), + (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3) + + + netE.train() + + + + + + + + if (epoch + 1) % opt.saveIter == 0: + + if should_diag: + + + save_dict = { + 'epoch': epoch, + 'model_state': netE.state_dict(), + 'optimizer_state': optimizer.state_dict() + } + + torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch)) + + + if opt.distribution_type == 'multi': + dist.barrier() + map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu} + netE.load_state_dict( + torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state']) + + dist.destroy_process_group() + +def main(): + opt = parse_args() + + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + + ''' workaround ''' + + train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes) + noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints) + + if opt.dist_url == "env://" and opt.world_size == -1: + opt.world_size = int(os.environ["WORLD_SIZE"]) + + if opt.distribution_type == 'multi': + opt.ngpus_per_node = torch.cuda.device_count() + opt.world_size = opt.ngpus_per_node * opt.world_size + mp.spawn(train, nprocs=opt.ngpus_per_node, 
args=(opt, output_dir, noises_init))
+    else:
+        train(opt.gpu, opt, output_dir, noises_init)
+
+
+
+def parse_args():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/', help='path to PartNet data')
+    parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc',
+                        help='path to raw ShapeNet SDF point clouds')
+    parser.add_argument('--pc_dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
+                        help='path to ShapeNetCore.v2.PC15k point clouds')
+    parser.add_argument('--classes', default='Table')
+
+    parser.add_argument('--bs', type=int, default=64, help='input batch size')
+    parser.add_argument('--workers', type=int, default=16, help='workers')
+    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
+
+    parser.add_argument('--nc', type=int, default=3)
+    parser.add_argument('--npoints', type=int, default=2048)
+    parser.add_argument('--svpoints', type=int, default=1024)
+    '''model'''
+    parser.add_argument('--beta_start', type=float, default=0.0001)
+    parser.add_argument('--beta_end', type=float, default=0.02)
+    parser.add_argument('--schedule_type', default='linear')
+    parser.add_argument('--time_num', type=int, default=1000)
+
+    #params
+    parser.add_argument('--attention', default=True)
+    parser.add_argument('--dropout', type=float, default=0.1)
+    parser.add_argument('--embed_dim', type=int, default=64)
+    parser.add_argument('--loss_type', default='mse')
+    parser.add_argument('--model_mean_type', default='eps')
+    parser.add_argument('--model_var_type', default='fixedsmall')
+
+    parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
+    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
+    parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
+    parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm (None disables clipping)')
+    parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')
+
+    parser.add_argument('--netE', default='', help="path to netE (to continue training)")
+
+
+    '''distributed'''
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='Number of distributed nodes.')
+    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
+                        help='url used to set up distributed training')
+    parser.add_argument('--dist_backend', default='nccl', type=str,
+                        help='distributed backend')
+    parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
+                        help='Use multi-processing distributed training to launch '
+                             'N processes per node, which has N GPUs. This is the '
+                             'fastest way to use PyTorch for either single node or '
+                             'multi node data parallel training')
+    parser.add_argument('--rank', default=0, type=int,
+                        help='node rank for distributed training')
+    parser.add_argument('--gpu', default=None, type=int,
+                        help='GPU id to use. 
None means using all available GPUs.') + + '''eval''' + parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch') + parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch') + parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch') + parser.add_argument('--print_freq', default=50, type=int,help='unit: iter') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + + opt = parser.parse_args() + + return opt + +if __name__ == '__main__': + main() diff --git a/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py new file mode 100644 index 0000000..b8a23f5 --- /dev/null +++ b/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py @@ -0,0 +1,822 @@ +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from model.pvcnn_completion import PVCNN2Base +import torch.distributed as dist +from datasets.partnet import GANdatasetPartNet +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +''' +some utils +''' +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. + """ + axis = np.asarray(axis) + axis = axis / np.sqrt(np.dot(axis, axis)) + a = np.cos(theta / 2.0) + b, c, d = -axis * np.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + +def rotate(vertices, faces): + ''' + vertices: [numpoints, 3] + ''' + M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() + N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() + K = rotation_matrix([0, 0, 1], np.pi).transpose() + + v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] + return v, f + +def norm(v, f): + v = (v - v.min())/(v.max() - v.min()) - 0.5 + + return v, f + +def getGradNorm(net): + pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) + gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) + return pNorm, gradNorm + + +def weights_init(m): + """ + xavier initialization + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1 and m.weight is not None: + torch.nn.init.xavier_normal_(m.weight) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_() + m.bias.data.fill_(0) + +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. 
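+    Both inputs are treated as diagonal Gaussians, so the result is the
+    elementwise closed form
+        0.5 * (-1 + logvar2 - logvar1 + exp(logvar1 - logvar2)
+               + (mean1 - mean2)^2 * exp(-logvar2)).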
+ """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
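+    Example: with a of shape [T], t of shape [B] and x_shape == (B, C, N),
+    the returned tensor has shape [B, 1, 1], ready to broadcast against x.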
+ """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, 
x_t.shape) * eps
+        )
+
+    ''' samples '''
+
+    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
+        """
+        Sample from the model at the given timestep.
+        """
+        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
+                                                                              return_pred_xstart=True)
+        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
+
+        # no noise when t == 0
+        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
+
+        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
+        sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
+        return (sample, pred_xstart) if return_pred_xstart else sample
+
+
+    def p_sample_loop(self, partial_x, denoise_fn, shape, device,
+                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
+        """
+        Generate samples conditioned on the partial shape.
+        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
+        """
+
+        assert isinstance(shape, (tuple, list))
+        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
+        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
+            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
+            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
+                                  clip_denoised=clip_denoised, return_pred_xstart=False)
+
+        assert img_t[:,:,self.sv_points:].shape == shape
+        return img_t
+
+    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
+                                 noise_fn=torch.randn, clip_denoised=True, keep_running=False):
+        """
+        Generate samples, returning intermediate images.
+        Useful for visualizing how denoised images evolve over time.
+        Args:
+            freq (int): snapshot interval; the running sample is stored every
+                `freq` timesteps (when t % freq == 0) and at the highest timestep.
+        """
+        assert isinstance(shape, (tuple, list))
+
+        total_steps = self.num_timesteps if not keep_running else len(self.betas)
+
+        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
+        imgs = [img_t]
+        for t in reversed(range(0, total_steps)):
+
+            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
+            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
+                                  clip_denoised=clip_denoised,
+                                  return_pred_xstart=False)
+            if t % freq == 0 or t == total_steps-1:
+                imgs.append(img_t)
+
+        assert imgs[-1].shape == shape
+        return imgs
+
+    '''losses'''
+
+    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
+        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
+        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
+            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
+
+        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
+        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
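+        # the division by np.log(2.) above converts the KL from nats to bits
+        # per dimension, matching the *_bpd (bits-per-dim) naming used below.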
+ + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 
2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + +def get_dataset(data_root, npoints, category): + + train_ds = GANdatasetPartNet('train', data_root, category, npoints) + return train_ds + + +def get_dataloader(opt, train_dataset, test_dataset=None): + + if opt.distribution_type == 'multi': + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=opt.world_size, + rank=opt.rank + ) + if test_dataset is not None: + 
test_sampler = torch.utils.data.distributed.DistributedSampler(
+                test_dataset,
+                num_replicas=opt.world_size,
+                rank=opt.rank
+            )
+        else:
+            test_sampler = None
+    else:
+        train_sampler = None
+        test_sampler = None
+
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs, sampler=train_sampler,
+                                                   shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
+
+    if test_dataset is not None:
+        # build the evaluation loader from the held-out split
+        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
+                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
+    else:
+        test_dataloader = None
+
+    return train_dataloader, test_dataloader, train_sampler, test_sampler
+
+
+def train(gpu, opt, output_dir, noises_init):
+
+    set_seed(opt)
+    logger = setup_logging(output_dir)
+    if opt.distribution_type == 'multi':
+        should_diag = gpu == 0
+    else:
+        should_diag = True
+    if should_diag:
+        outf_syn, = setup_output_subdirs(output_dir, 'syn')
+
+    if opt.distribution_type == 'multi':
+        if opt.dist_url == "env://" and opt.rank == -1:
+            opt.rank = int(os.environ["RANK"])
+
+        base_rank = opt.rank * opt.ngpus_per_node
+        opt.rank = base_rank + gpu
+        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
+                                world_size=opt.world_size, rank=opt.rank)
+
+        opt.bs = int(opt.bs / opt.ngpus_per_node)
+        opt.workers = 0
+
+        opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
+        opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
+        opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
+
+
+    ''' data '''
+    train_dataset = get_dataset(opt.data_root, opt.npoints, opt.classes)
+    dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
+
+
+    '''
+    create networks
+    '''
+
+    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
+    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
+
+    if opt.distribution_type == 'multi':  # Multiple processes, single GPU per process
+        def _transform_(m):
+            return nn.parallel.DistributedDataParallel(
+                m, device_ids=[gpu], output_device=gpu)
+
+        torch.cuda.set_device(gpu)
+        netE.cuda(gpu)
+        netE.multi_gpu_wrapper(_transform_)
+
+
+    elif opt.distribution_type == 'single':
+        def _transform_(m):
+            return nn.parallel.DataParallel(m)
+        netE = netE.cuda()
+        netE.multi_gpu_wrapper(_transform_)
+
+    elif gpu is not None:
+        torch.cuda.set_device(gpu)
+        netE = netE.cuda(gpu)
+    else:
+        raise ValueError('distribution_type = multi | single | None')
+
+    if should_diag:
+        logger.info(opt)
+
+    optimizer = optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))
+
+    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)
+
+    if opt.netE != '':
+        # resume model, optimizer and epoch counter from a single checkpoint load
+        ckpt = torch.load(opt.netE)
+        netE.load_state_dict(ckpt['model_state'])
+        optimizer.load_state_dict(ckpt['optimizer_state'])
+        start_epoch = ckpt['epoch'] + 1
+    else:
+        start_epoch = 0
+
+
+    for epoch in range(start_epoch, opt.niter):
+
+        if opt.distribution_type == 'multi':
+            train_sampler.set_epoch(epoch)
+
+        lr_scheduler.step(epoch)
+
+        for i, data in enumerate(dataloader):
+
+            x = data['real']
+            sv_x = data['raw']
+
+            sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
+            noises_batch = noises_init[data['idx']]
+
+            '''
+            train diffusion
+            '''
+
+            if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
+                sv_x = sv_x.cuda(gpu)
+                noises_batch = noises_batch.cuda(gpu)
+            elif opt.distribution_type == 'single':
+                sv_x = 
sv_x.cuda() + noises_batch = noises_batch.cuda() + + loss = netE.get_loss_iter(sv_x, noises_batch).mean() + + optimizer.zero_grad() + loss.backward() + netpNorm, netgradNorm = getGradNorm(netE) + if opt.grad_clip is not None: + torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip) + + optimizer.step() + + + if i % opt.print_freq == 0 and should_diag: + + logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, ' + 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' + .format( + epoch, opt.niter, i, len(dataloader),loss.item(), + netpNorm, netgradNorm, + )) + + if (epoch + 1) % opt.diagIter == 0 and should_diag: + + logger.info('Diagnosis:') + + x_range = [x.min().item(), x.max().item()] + kl_stats = netE.all_kl(sv_x) + logger.info(' [{:>3d}/{:>3d}] ' + 'x_range: [{:>10.4f}, {:>10.4f}], ' + 'total_bpd_b: {:>10.4f}, ' + 'terms_bpd: {:>10.4f}, ' + 'prior_bpd_b: {:>10.4f} ' + 'mse_bt: {:>10.4f} ' + .format( + epoch, opt.niter, + *x_range, + kl_stats['total_bpd_b'].item(), + kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item() + )) + + + + if (epoch + 1) % opt.vizIter == 0 and should_diag: + logger.info('Generation: eval') + + netE.eval() + m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0) + + with torch.no_grad(): + + x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu() + + + gen_stats = [x_gen_eval.mean(), x_gen_eval.std()] + gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()] + + logger.info(' [{:>3d}/{:>3d}] ' + 'eval_gen_range: [{:>10.4f}, {:>10.4f}] ' + 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ' + .format( + epoch, opt.niter, + *gen_eval_range, *gen_stats, + )) + + export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch), + (x_gen_eval*s+m).transpose(1, 2).numpy()*3) + + export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch), + (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3) + + export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch), + (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3) + + + netE.train() + + + + + + + + if (epoch + 1) % opt.saveIter == 0: + + if should_diag: + + + save_dict = { + 'epoch': epoch, + 'model_state': netE.state_dict(), + 'optimizer_state': optimizer.state_dict() + } + + torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch)) + + + if opt.distribution_type == 'multi': + dist.barrier() + map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu} + netE.load_state_dict( + torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state']) + + dist.destroy_process_group() + +def main(): + opt = parse_args() + + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + + ''' workaround ''' + + train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes) + noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints) + + if opt.dist_url == "env://" and opt.world_size == -1: + opt.world_size = int(os.environ["WORLD_SIZE"]) + + if opt.distribution_type == 'multi': + opt.ngpus_per_node = torch.cuda.device_count() + opt.world_size = opt.ngpus_per_node * opt.world_size + mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init)) + else: + train(opt.gpu, opt, output_dir, noises_init) + + + +def 
parse_args():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet/', help='path to PartNet data')
+
+    parser.add_argument('--classes', default='Table')
+
+    parser.add_argument('--bs', type=int, default=64, help='input batch size')
+    parser.add_argument('--workers', type=int, default=16, help='workers')
+    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
+
+    parser.add_argument('--nc', type=int, default=3)
+    parser.add_argument('--npoints', type=int, default=2048)
+    parser.add_argument('--svpoints', type=int, default=1024)
+    '''model'''
+    parser.add_argument('--beta_start', type=float, default=0.0001)
+    parser.add_argument('--beta_end', type=float, default=0.02)
+    parser.add_argument('--schedule_type', default='linear')
+    parser.add_argument('--time_num', type=int, default=1000)
+
+    #params
+    parser.add_argument('--attention', default=True)
+    parser.add_argument('--dropout', type=float, default=0.1)
+    parser.add_argument('--embed_dim', type=int, default=64)
+    parser.add_argument('--loss_type', default='mse')
+    parser.add_argument('--model_mean_type', default='eps')
+    parser.add_argument('--model_var_type', default='fixedsmall')
+
+    parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
+    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
+    parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
+    parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm (None disables clipping)')
+    parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')
+
+    parser.add_argument('--netE', default='', help="path to netE (to continue training)")
+
+
+    '''distributed'''
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='Number of distributed nodes.')
+    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
+                        help='url used to set up distributed training')
+    parser.add_argument('--dist_backend', default='nccl', type=str,
+                        help='distributed backend')
+    parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
+                        help='Use multi-processing distributed training to launch '
+                             'N processes per node, which has N GPUs. This is the '
+                             'fastest way to use PyTorch for either single node or '
+                             'multi node data parallel training')
+    parser.add_argument('--rank', default=0, type=int,
+                        help='node rank for distributed training')
+    parser.add_argument('--gpu', default=None, type=int,
+                        help='GPU id to use. 
None means using all available GPUs.') + + '''eval''' + parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch') + parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch') + parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch') + parser.add_argument('--print_freq', default=50, type=int,help='unit: iter') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + + opt = parser.parse_args() + + return opt + +if __name__ == '__main__': + main() diff --git a/shape_completion/__init__.py b/shape_completion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shape_completion/control_gen_chair.py b/shape_completion/control_gen_chair.py new file mode 100644 index 0000000..96a1387 --- /dev/null +++ b/shape_completion/control_gen_chair.py @@ -0,0 +1,660 @@ + +from pprint import pprint +from metrics.evaluation_metrics import compute_all_metrics, EMD_CD + +import torch.nn as nn +import torch.utils.data + +import argparse +from torch.distributions import Normal +from utils.visualize import * +from utils.file_utils import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_completion import PVCNN2Base + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. 
- betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = 
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0,total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + + + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. 
seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb, denoise_fn, noise_fn=torch.randn): + + assert t >= 1 + + t_vec = torch.empty(x0_part.shape[0], dtype=torch.int64, device=x0_part.device).fill_(t-1) + encoding0 = self.q_sample(x0_part, t_vec) + encoding1 = self.q_sample(x1_part, t_vec) + + enc = encoding0 * (1-lamb) + (lamb) * encoding1 + + img_t = torch.cat([torch.cat([x0_sv[:,:,:int(self.sv_points*(1-lamb))], x1_sv[:,:,:(self.sv_points - int(self.sv_points*(1-lamb)))]], dim=-1), enc], dim=-1) + + for k in reversed(range(0,t)): + t_ = torch.empty(img_t.shape[0], dtype=torch.int64, device=img_t.device).fill_(k) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False).detach() + + + return img_t + + +class PVCNN2(PVCNN2Base): + sa_blocks 
= [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb): + + return self.diffusion.interpolate(x0_part, x1_part, x0_sv, x1_sv, t, lamb, self._denoise) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise 
NotImplementedError(schedule_type) + return betas + + + +############################################################################# +def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + radius=3, elev=-89, azim=180, img_size=512, focal_length=1000, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + +def get_mvr_dataset(pc_dataroot, views_root, npoints,category, get_image=True): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, get_image=get_image, + ) + return te_dataset + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + if i!=3: + continue + gt_all = data['test_points'] + x_all = data['sv_points'] + img_all = data['image'] + + for v in range(20): + + recons = [] + svs = [] + for p in [0,1]: + x = x_all[:,p].transpose(1, 2).contiguous() + img = img_all[:,p] + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + recons.append(recon) + svs.append(x[:, :opt.svpoints,:]) + + for l, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): + + # im = np.fliplr(np.flipud(d[-1])) + write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,l), 'mode_%03d'%v, 'depth_%d'%p), + (torch.stack(d[:-1], dim=0)* s[0] + m[0]).numpy(), cat='chair') + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,l), 'mode_%03d'%v, 'depth_%d'%p), + (torch.stack(d[:-1], dim=0)* s[0] + m[0]).numpy()) + plt.imsave(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v, 'depth_%d.png' % p), + d[-1].permute(1, 2, 0), cmap='gray') + + x0_part = recons[0].transpose(1, 2).contiguous()[:,:,opt.svpoints:].cuda() + x1_part = recons[1].transpose(1, 2).contiguous()[:,:,opt.svpoints:].cuda() + x0_sv = svs[0].transpose(1,2).cuda() + x1_sv = svs[1].transpose(1,2).cuda() + + interres = [] + for lamb in np.linspace(0.1, 0.9, 5): + res = netE.interpolate(x0_part, x1_part, x0_sv, x1_sv, 1000, lamb) + + res = torch.cat([x0_sv, x1_sv, 
res[:,:,opt.svpoints:]], dim=-1).detach().cpu().transpose(1,2).contiguous()
+                interres.append(res)
+            for l, d in enumerate(torch.stack(interres, dim=1)):
+                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, l), 'mode_%03d' % v),
+                                   (d * s[0] + m[0]).numpy(), cat='chair')
+                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v),
+                                   (d * s[0] + m[0]).numpy())
+
+def main(opt):
+    exp_id = os.path.splitext(os.path.basename(__file__))[0]
+    dir_id = os.path.dirname(__file__)
+    output_dir = get_output_dir(dir_id, exp_id)
+    copy_source(__file__, output_dir)
+    logger = setup_logging(output_dir)
+
+    outf_syn, = setup_output_subdirs(output_dir, 'syn')
+
+    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
+    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
+
+    if opt.cuda:
+        netE = netE.cuda()
+
+        def _transform_(m):
+            return nn.parallel.DataParallel(m)
+
+        netE.multi_gpu_wrapper(_transform_)
+
+    netE.eval()
+
+    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
+
+    with torch.no_grad():
+        # evaluate checkpoints from newest to oldest, ordered by epoch number
+        for ckpt in reversed(sorted(ckpts, key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))):
+
+            opt.netE = ckpt
+            logger.info("Resume Path:%s" % opt.netE)
+
+            resumed_param = torch.load(opt.netE)
+            netE.load_state_dict(resumed_param['model_state'])
+
+
+            generate_multimodal(opt, netE, outf_syn, logger)
+
+
+def parse_args():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to ShapeNetCore.v2.PC15k point clouds')
+    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
+                        help='path to preprocessed multi-view data')
+    parser.add_argument('--classes', default=['chair'])
+
+    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
+    parser.add_argument('--workers', type=int, default=16, help='workers')
+    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
+
+    parser.add_argument('--eval_recon_mvr', default=False)
+    parser.add_argument('--generate_multimodal', default=False)
+    parser.add_argument('--eval_saved', default=False)
+    parser.add_argument('--eval_redwood', default=True)
+
+    parser.add_argument('--nc', type=int, default=3)
+    parser.add_argument('--npoints', type=int, default=2048)
+    parser.add_argument('--svpoints', type=int, default=200)
+    '''model'''
+    parser.add_argument('--beta_start', type=float, default=0.0001)
+    parser.add_argument('--beta_end', type=float, default=0.02)
+    parser.add_argument('--schedule_type', default='linear')
+    parser.add_argument('--time_num', type=int, default=1000)
+
+    #params
+    parser.add_argument('--attention', default=True)
+    parser.add_argument('--dropout', type=float, default=0.1)
+    parser.add_argument('--embed_dim', type=int, default=64)
+    parser.add_argument('--loss_type', default='mse')
+    parser.add_argument('--model_mean_type', default='eps')
+    parser.add_argument('--model_var_type', default='fixedsmall')
+
+
+    # constrain function
+    parser.add_argument('--constrain_eps', type=float, default=0.2)
+    parser.add_argument('--constrain_steps', type=int, default=1)
+
+    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help='directory containing the checkpoints to evaluate')
+
+    '''eval'''
+
+    parser.add_argument('--eval_path',
+ 
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') + + parser.add_argument('--manualSeed', default=None, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + set_seed(opt) + main(opt) diff --git a/shape_completion/teaser_chair.py b/shape_completion/teaser_chair.py new file mode 100644 index 0000000..5d022e5 --- /dev/null +++ b/shape_completion/teaser_chair.py @@ -0,0 +1,706 @@ + + +import torch.nn as nn +import torch.utils.data + +import argparse +from torch.distributions import Normal +from utils.visualize import * +from utils.file_utils import * +from utils.mitsuba_renderer2 import write_to_xml_batch, write_to_xml +from model.pvcnn_completion import PVCNN2Base + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. 
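+        # These buffers are the closed-form moments of the forward process
+        #   q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I),
+        # with abar_t = prod_{s<=t}(1 - beta_s). The sqrt_recip_* pair inverts
+        # that map, so _predict_xstart_from_eps below can recover
+        #   x_0 = sqrt(1/abar_t) * x_t - sqrt(1/abar_t - 1) * eps
+        # from a noisy sample and a predicted noise.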
- alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), 
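+                # 'fixedlarge' uses beta_t itself, 'fixedsmall' the true posterior
+                # variance beta~_t = (1 - abar_{t-1}) / (1 - abar_t) * beta_t; both
+                # are the untrained variance choices from Ho et al. (2020). Note the
+                # slice [:, :, self.sv_points:] above: the network sees the full
+                # cloud, but only the free (to-be-completed) points are denoised,
+                # which is how the observed partial shape conditions the reverse
+                # process.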
self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + + def p_sample_loop_trajectory2(self, partial_x, denoise_fn, shape, device, num_save, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. 
+ """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + scale = np.exp(np.log(1/total_steps)/num_save) + save_step = total_steps + + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + imgs = [img_t.detach().cpu()] + for t in reversed(range(0,total_steps)): + if (t+1) == save_step and t > 0 and len(imgs) 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = 
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0,total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. 
seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, 
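+        # Editor's note: the sa_blocks / fp_blocks tuples above appear to follow
+        # PVCNN's convention (an assumption; the format is not restated in this
+        # file):
+        #   ((out_channels, num_conv_blocks, voxel_resolution),
+        #    (num_centers, ball_radius, num_neighbors, mlp_channels))
+        # i.e. a point-voxel convolution spec paired with a PointNet++-style
+        # set-abstraction spec, with fp_blocks mirroring it for upsampling.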
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + + +############################################################################# +def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + radius=3, elev=0, azim=0, img_size=512, focal_length=1000, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + +def get_mvr_dataset(pc_dataroot, views_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + 
tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def evaluate_recon_mvr(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + img_all = data['image'] + + + B,V,N,C = x_all.shape + gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + + x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + img = img_all.reshape(B * V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + images.append(img) + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + img_all = data['image'] + + Path(os.path.join(save_dir, 'x_%03d'%k)).mkdir(exist_ok=True, parents=True) + Path(os.path.join(save_dir, 'x_ply_%03d' % k)).mkdir(exist_ok=True, 
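+        # Editor's note: `k` appears undefined in this function (it is a loop
+        # counter in evaluate_recon_mvr above), so as written these two mkdir
+        # calls would raise NameError. The per-sample directories used below
+        # are keyed by the batch index, so `k` here is presumably a leftover
+        # and likely meant `i`.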
parents=True) + + for v in range(5): + x = x_all.transpose(1, 2).contiguous() + img = img_all + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): + + im = np.fliplr(np.flipud(d[-1])) + plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') + write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + + opt.netE = ckpt + logger.info("Resume Path:%s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + if opt.eval_recon_mvr: + # Evaluate generation + evaluate_recon_mvr(opt, netE, outf_syn, logger) + + if opt.generate_multimodal: + + generate_multimodal(opt, netE, outf_syn, logger) + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + help='input batch size') + parser.add_argument('--classes', default=['car']) + + parser.add_argument('--batch_size', type=int, default=8, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--eval_recon_mvr', default=True) + parser.add_argument('--generate_multimodal', default=False) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=200) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + # constrain function + 
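+    # Editor's note: the two constrain-function flags below are parsed but do
+    # not appear to be used anywhere in the code shown for this script. Also,
+    # the boolean-ish options above (e.g. --eval_recon_mvr,
+    # --generate_multimodal) use default=True/False without type= or
+    # action='store_true', so any value supplied on the command line arrives
+    # as a non-empty, hence truthy, string; in practice they behave as
+    # edit-in-place config constants.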
parser.add_argument('--constrain_eps', default=0.2) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/3_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-03-08-40', help="path to netE (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + set_seed(opt) + + main(opt) diff --git a/shape_completion/test_chair.py b/shape_completion/test_chair.py new file mode 100644 index 0000000..dda8e3c --- /dev/null +++ b/shape_completion/test_chair.py @@ -0,0 +1,753 @@ + +from pprint import pprint +from metrics.evaluation_metrics import compute_all_metrics, EMD_CD + +import torch.nn as nn +import torch.utils.data + +import argparse +from torch.distributions import Normal +from utils.visualize import * +from utils.file_utils import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_completion import PVCNN2Base + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. 
- betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
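+        # _extract (defined above) is a per-batch gather plus reshape. Sketch
+        # of the shapes involved (hypothetical values, B=2, x.shape=[2, 3, 2048]):
+        #   a = tensor([a_0, ..., a_{T-1}])     # one coefficient per timestep
+        #   t = tensor([3, 7])                  # a timestep per batch item
+        #   _extract(a, t, x.shape)             # -> [[[a_3]], [[a_7]]], shape [2, 1, 1]
+        # which then broadcasts against the [B, 3, N] point clouds.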
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = 
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0,total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. 
seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + + +############################################################################# +def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + radius=3, elev=-89, azim=180, img_size=512, focal_length=1000, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + +def get_mvr_dataset(pc_dataroot, views_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, 
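+        # The train split is instantiated here only for its dataset-wide
+        # statistics: its all_points_mean / all_points_std are handed to the
+        # val dataset below so both splits share one normalization, and the
+        # same (m, s) tensors later un-normalize the generated completions
+        # before metrics and export.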
split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def evaluate_recon_mvr(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + # img_all = data['image'] + + + B,V,N,C = x_all.shape + gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + + x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + # img = img_all.reshape(B * V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + # img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + # images.append(img) + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + # images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + del ref_pcs, masked, results + +def evaluate_saved(opt, netE, save_dir, logger): + ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn' + + gt_pth = ours_base + '/recon_gt.pth' + ours_pth = ours_base + '/ours_results.pth' + gt = torch.load(gt_pth).permute(1,0,2,3) + ours = torch.load(ours_pth)['pc'].permute(1,0,2,3) + + all_res = {} + for i, (gt_, ours_) in enumerate(zip(gt, ours)): + results = compute_all_metrics(gt_, ours_, opt.batch_size) + + for key, val in results.items(): + if i == 0: + all_res[key] = val + else: + all_res[key] += val + pprint(results) + for key, val in all_res.items(): + all_res[key] = val / gt.shape[0] + + pprint({key: 
val.mean().item() for key, val in all_res.items()}) + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + img_all = data['image'] + + for v in range(6): + x = x_all.transpose(1, 2).contiguous() + img = img_all + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): + + im = np.fliplr(np.flipud(d[-1])) + plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') + write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair') + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + +def redwood_demo(opt, netE, save_dir, logger): + import open3d as o3d + pth = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc_partial.ply" + pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc.ply" + + points = np.asarray(o3d.io.read_point_cloud(pth).points) + + gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points) + + np.save('gt.npy', gt_points) + + test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + + m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float() + + x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float() + x = (x-m)/s + + x = x.transpose(1,2).cuda() + + res = [] + for k in range(20): + recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + recon = recon.transpose(1, 2).contiguous() + recon = recon * s+ m + res.append(recon) + res = torch.cat(res, dim=0) + + write_to_xml_batch(os.path.join(save_dir, 'xml'), + (res).numpy(), cat='chair') + + export_to_pc_batch(os.path.join(save_dir, 'ply'), + (res).numpy()) + + torch.save(res, os.path.join(save_dir, 'redwood_demo.pth')) + + pcwrite(os.path.join(save_dir, 'ply', 'gt.ply'), + gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)]) + write_to_xml_batch(os.path.join(save_dir, 'xml_gt'), + gt_points[None], cat='chair') + + exit() + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + 
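redwood_demo conditions the sampler on an out-of-distribution scan by standardizing it with the training set's mean/std and un-standardizing the samples afterwards. A shape-only sketch of that preprocessing (paths, statistics, and sizes are placeholders; note the scan tensor is 2-D at the point where the script calls transpose(1, 2), so a batch dimension appears to be assumed and is added explicitly here):

```python
import torch

points = torch.randn(5000, 3)               # stand-in for the loaded .ply scan
m, s = torch.zeros(1, 3), torch.ones(1, 3)  # stand-ins for the dataset mean / std
npoints, svpoints = 2048, 200

idx = torch.randperm(points.shape[0])[:npoints]   # subsample without replacement
x = (points[idx] - m) / s                         # normalize like the training data
x = x.t()[None]                                   # (1, 3, npoints): channels-first batch
partial, free_shape = x[:, :, :svpoints], x[:, :, svpoints:].shape
assert partial.shape == (1, 3, svpoints) and free_shape == (1, 3, npoints - svpoints)
```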
netE.multi_gpu_wrapper(_transform_) + + netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + + opt.netE = ckpt + logger.info("Resume Path:%s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + if opt.eval_recon_mvr: + # Evaluate generation + evaluate_recon_mvr(opt, netE, outf_syn, logger) + + if opt.generate_multimodal: + + generate_multimodal(opt, netE, outf_syn, logger) + + if opt.eval_saved: + evaluate_saved(opt, netE, outf_syn, logger) + + if opt.eval_redwood: + redwood_demo(opt, netE, outf_syn, logger) + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + help='input batch size') + parser.add_argument('--classes', default=['chair']) + + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--eval_recon_mvr', default=False) + parser.add_argument('--generate_multimodal', default=False) + parser.add_argument('--eval_saved', default=False) + parser.add_argument('--eval_redwood', default=True) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=200) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + # constrain function + parser.add_argument('--constrain_eps', default=0.2) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + + main(opt) diff --git a/shape_completion/test_completion.py b/shape_completion/test_completion.py new file mode 100644 index 0000000..9be1ed7 --- /dev/null +++ b/shape_completion/test_completion.py @@ -0,0 +1,634 @@ + +from pprint import pprint +from metrics.evaluation_metrics import compute_all_metrics, 
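main() above walks the checkpoint directory newest-first, keyed on the epoch number embedded in the filename. The `x.strip('.pth')` used for that strips a *character set* rather than the suffix (it happens to work here because the names end in digits); a regex states the intent more safely. A sketch under that assumption (the helper name is ours):

```python
import os
import re

def checkpoint_epoch(path: str) -> int:
    # Pull the trailing epoch number out of names like 'epoch_2799.pth'.
    m = re.search(r'_(\d+)\.pth$', os.path.basename(path))
    if m is None:
        raise ValueError('unrecognized checkpoint name: %s' % path)
    return int(m.group(1))

ckpts = ['epoch_99.pth', 'epoch_2799.pth', 'epoch_100.pth']
assert sorted(ckpts, key=checkpoint_epoch, reverse=True)[0] == 'epoch_2799.pth'
```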
EMD_CD + +import torch.nn as nn +import torch.utils.data + +import argparse +from torch.distributions import Normal +from utils.visualize import * +from utils.file_utils import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_completion import PVCNN2Base + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. 
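normal_kl above is the elementwise closed form KL(N(m1, e^lv1) || N(m2, e^lv2)) between diagonal Gaussians parameterized by mean and log-variance. A quick sanity check of the formula:

```python
import torch

def normal_kl(mean1, logvar1, mean2, logvar2):
    # Same closed form as in the file, elementwise over diagonal Gaussians.
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * torch.exp(-logvar2))

zero = torch.zeros(3)
assert torch.allclose(normal_kl(zero, zero, zero, zero), zero)   # identical -> 0
assert torch.allclose(normal_kl(zero + 1., zero, zero, zero),
                      torch.full((3,), 0.5))                     # unit-variance: (dm)^2 / 2
```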
- alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = 
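_extract is the small broadcasting helper the whole file leans on: it gathers one scalar coefficient per batch element at that element's timestep, then reshapes to (B, 1, ..., 1) so the result multiplies a batch cleanly. A standalone replica:

```python
import torch

def extract(a: torch.Tensor, t: torch.Tensor, x_shape) -> torch.Tensor:
    # Gather a[t] per batch element, then reshape for broadcasting against x_shape.
    bs, = t.shape
    out = torch.gather(a, 0, t)
    return out.reshape(bs, *([1] * (len(x_shape) - 1)))

coeffs = torch.linspace(0., 1., 1000)      # e.g. sqrt_alphas_cumprod
t = torch.tensor([0, 499, 999])
x = torch.randn(3, 3, 2048)
assert extract(coeffs, t, x.shape).shape == (3, 1, 1)
```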
self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. 
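The completion-specific twist sits in p_sample and p_sample_loop: only the unobserved points are ever noised and denoised, and the first sv_points columns (the observed partial shape) are re-attached after every step, so the condition is preserved exactly rather than approximately. A shape-only illustration:

```python
import torch

B, C, sv_points, n_free = 4, 3, 200, 1848
partial_x = torch.randn(B, C, sv_points)   # observed partial shape, never noised
free = torch.randn(B, C, n_free)           # stands in for one denoising step's output

img_t = torch.cat([partial_x, free], dim=-1)
assert img_t.shape == (B, C, sv_points + n_free)
assert torch.equal(img_t[:, :, :sv_points], partial_x)   # condition preserved exactly
```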
+ """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0,total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) 
+
+    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
+
+        with torch.no_grad():
+            B, T = x_start.shape[0], self.num_timesteps
+
+            vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
+            for t in reversed(range(T)):
+
+                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
+                # Calculate VLB term at the current timestep
+                data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
+                new_vals_b, pred_xstart = self._vb_terms_bpd(
+                    denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
+                    clip_denoised=clip_denoised, return_pred_xstart=True)
+                # MSE for progressive prediction loss
+                assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
+                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
+                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
+                # Insert the calculated term into the tensor of all terms
+                mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
+                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
+                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
+                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
+
+            prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points:])
+            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
+            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
+                   total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
+            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
+
+
+class PVCNN2(PVCNN2Base):
+    sa_blocks = [
+        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
+        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
+        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
+        (None, (16, 0.8, 32, (256, 256, 512))),
+    ]
+    fp_blocks = [
+        ((256, 256), (256, 3, 8)),
+        ((256, 256), (256, 3, 8)),
+        ((256, 128), (128, 2, 16)),
+        ((128, 128, 64), (64, 2, 32)),
+    ]
+
+    def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1,
+                 voxel_resolution_multiplier=1):
+        super().__init__(
+            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
+            dropout=dropout, extra_feature_channels=extra_feature_channels,
+            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
+        )
+
+
+class Model(nn.Module):
+    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
+        super(Model, self).__init__()
+        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
+
+        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
+                            dropout=args.dropout, extra_feature_channels=0)
+
+    def prior_kl(self, x0):
+        return self.diffusion._prior_bpd(x0)
+
+    def all_kl(self, x0, clip_denoised=True):
+        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
+
+        return {
+            'total_bpd_b': total_bpd_b,
+            'terms_bpd': vals_bt,
+            'prior_bpd_b': prior_bpd_b,
+            'mse_bt': mse_bt
+        }
+
+    def _denoise(self, data, t):
+        B, D, N = data.shape
+        assert data.dtype == torch.float
+        assert t.shape == torch.Size([B]) and t.dtype == torch.int64
+
+        out = self.model(data, t)
+
+        return out
+
+    def get_loss_iter(self, data, noises=None):
+        B, D, N = data.shape
+        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,),
device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + + +############################################################################# + +def get_mvr_dataset(pc_dataroot, views_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=[category], split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=[category], + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def evaluate_recon_mvr(opt, model, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.category) + + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + + B,V,N,C = x_all.shape + gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + + x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + + m, s = data['mean'].float(), data['std'].float() + + recon = model.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + masked = torch.cat(masked, dim=0) + 
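EMD_CD returns its metrics flattened over shapes x views; evaluate_recon_mvr folds any (B*V,)-shaped entry back to (B, V) before averaging and logging. A replica of that dictionary pass (the metric key is a placeholder):

```python
import torch

B, V = 3, 4
results = {'MMD-CD': torch.arange(B * V, dtype=torch.float)}   # placeholder metric
results = {k: v.reshape(B, V) if v.shape == torch.Size([B * V]) else v
           for k, v in results.items()}
assert results['MMD-CD'].shape == (B, V)
print({k: v.mean().item() for k, v in results.items()})
```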
+ B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + + torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + del ref_pcs, masked, results + +def evaluate_saved(opt, saved_dir): + # ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn' + + gt_pth = saved_dir + '/recon_gt.pth' + ours_pth = saved_dir + '/ours_results.pth' + gt = torch.load(gt_pth).permute(1,0,2,3) + ours = torch.load(ours_pth)['pc'].permute(1,0,2,3) + + all_res = {} + for i, (gt_, ours_) in enumerate(zip(gt, ours)): + results = compute_all_metrics(gt_, ours_, opt.batch_size) + + for key, val in results.items(): + if i == 0: + all_res[key] = val + else: + all_res[key] += val + pprint(results) + for key, val in all_res.items(): + all_res[key] = val / gt.shape[0] + + pprint({key: val.mean().item() for key, val in all_res.items()}) + + + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + model.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + model = model.cuda() + model.multi_gpu_wrapper(_transform_) + + model.eval() + + with torch.no_grad(): + + logger.info("Resume Path:%s" % opt.model) + + resumed_param = torch.load(opt.model) + model.load_state_dict(resumed_param['model_state']) + + + if opt.eval_recon_mvr: + # Evaluate generation + evaluate_recon_mvr(opt, model, outf_syn, logger) + + if opt.eval_saved: + evaluate_saved(opt, outf_syn) + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + help='input batch size') + parser.add_argument('--category', default='chair') + + parser.add_argument('--batch_size', type=int, default=50, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--eval_recon_mvr', default=True) + parser.add_argument('--eval_saved', default=True) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=200) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + 
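evaluate_saved below averages the per-view metric dictionaries by accumulating values across the view axis and dividing by the number of views at the end. A compact replica of that reduction with fake values:

```python
import torch

per_view = [{'CD': torch.tensor(1.0)}, {'CD': torch.tensor(3.0)}]  # fake view metrics
all_res = {}
for i, results in enumerate(per_view):
    for key, val in results.items():
        all_res[key] = val if i == 0 else all_res[key] + val
all_res = {k: v / len(per_view) for k, v in all_res.items()}
assert all_res['CD'].item() == 2.0
```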
parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + parser.add_argument('--model', default='', required=True, help="path to model (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + + main(opt) diff --git a/shape_completion/test_partnet_chair.py b/shape_completion/test_partnet_chair.py new file mode 100644 index 0000000..babb8a0 --- /dev/null +++ b/shape_completion/test_partnet_chair.py @@ -0,0 +1,599 @@ + +from pprint import pprint +from tqdm import tqdm +import torch.nn as nn +import torch.utils.data + +import argparse +from torch.distributions import Normal +from utils.visualize import * +from utils.file_utils import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_completion import PVCNN2Base + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.partnet import GANdatasetPartNet +import trimesh +import csv +import numpy as np +import random +from plyfile import PlyData, PlyElement + + +def write_ply(points, filename, text=False): + """ input: Nx3, write points to filename as PLY format. """ + points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] + vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) + el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) + with open(filename, mode='wb') as f: + PlyData([el], text=text).write(f) + +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. 
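write_ply in the PartNet scripts serializes an (N, 3) float array as a binary PLY with a single 'vertex' element via plyfile. A minimal usage sketch following the same pattern (the output filename is illustrative):

```python
import numpy as np
from plyfile import PlyData, PlyElement

pts = np.random.rand(100, 3).astype(np.float32)
vertex = np.array([tuple(p) for p in pts],
                  dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
with open('example.ply', mode='wb') as f:   # binary PLY, as in write_ply
    PlyData([el], text=False).write(f)
```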
- cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = 
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0,total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. 
seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + + +############################################################################# +def get_dataset(data_root, npoints, category): + + train_ds = GANdatasetPartNet('test', data_root, category, npoints) + return train_ds + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['real'] + x_all = data['raw'] + + for j in range(5): + x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1) + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + 
clip_denoised=False).detach().cpu() + + + + + for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))): + + partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1) + rec = d[1] + rid = d[2] + write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j), + (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair') + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j), + (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy()) + + raw_id = rid.split('.')[0] + save_sample_dir = os.path.join(save_dir, "{}".format(raw_id)) + Path(save_sample_dir).mkdir(parents=True, exist_ok=True) + # save input partial shape + if j == 0: + save_path = os.path.join(save_sample_dir, "raw.ply") + write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path) + # save completed shape + save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j)) + write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path) + + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + + opt.netE = ckpt + logger.info("Resume Path:%s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + + if opt.generate_multimodal: + + generate_multimodal(opt, netE, outf_syn, logger) + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet', + help='input batch size') + + parser.add_argument('--classes', default='Chair') + + parser.add_argument('--batch_size', type=int, default=64, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--eval_recon_mvr', default=False) + parser.add_argument('--generate_multimodal', default=True) + parser.add_argument('--eval_saved', default=False) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=1024) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + # constrain function + 
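In the PartNet script the partial input carries only npoints - svpoints raw points, so before rendering it is padded back to full size by repeating its first point (d[0][:, 0:1].expand(-1, svpoints)), which lets it be stacked with the completed cloud. A shape-only replica (sizes mirror the script's defaults):

```python
import torch

svpoints, n_raw = 1024, 1024          # script defaults: npoints=2048, svpoints=1024
raw = torch.randn(3, n_raw)           # one partial cloud, channels-first
partial = torch.cat([raw, raw[:, 0:1].expand(-1, svpoints)], dim=-1)
assert partial.shape == (3, n_raw + svpoints)
# every padded column is a copy of the first raw point
assert torch.equal(partial[:, n_raw:], raw[:, 0:1].expand(-1, svpoints))
```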
parser.add_argument('--constrain_eps', default=0.2) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + + main(opt) diff --git a/shape_completion/test_partnet_table.py b/shape_completion/test_partnet_table.py new file mode 100644 index 0000000..8efbde5 --- /dev/null +++ b/shape_completion/test_partnet_table.py @@ -0,0 +1,599 @@ + +from pprint import pprint +from tqdm import tqdm +import torch.nn as nn +import torch.utils.data + +import argparse +from torch.distributions import Normal +from utils.visualize import * +from utils.file_utils import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_completion import PVCNN2Base + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.partnet import GANdatasetPartNet +import trimesh +import csv +import numpy as np +import random +from plyfile import PlyData, PlyElement + + +def write_ply(points, filename, text=False): + """ input: Nx3, write points to filename as PLY format. """ + points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] + vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) + el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) + with open(filename, mode='wb') as f: + PlyData([el], text=text).write(f) + +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. 
- cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.model_mean_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + )
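A quick sanity check on `_predict_xstart_from_eps`: it is the exact algebraic inverse of the `q_sample` closed form x_t = sqrt(a_t)*x_0 + sqrt(1 - a_t)*eps (with a_t the cumulative alpha product), since sqrt(1/a_t - 1) = sqrt(1 - a_t)/sqrt(a_t). A minimal standalone sketch of that identity, with an illustrative schedule and shapes that are not tied to this repo's defaults:

```python
import torch

# toy linear schedule in the spirit of get_betas('linear', 0.0001, 0.02, 1000)
betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1. - betas, dim=0)

x0 = torch.randn(4, 3, 1024, dtype=torch.float64)  # B x C x N, like the free points
eps = torch.randn_like(x0)
t = torch.randint(0, 1000, (4,))
a = alphas_cumprod[t].view(-1, 1, 1)               # one coefficient per batch element

x_t = a.sqrt() * x0 + (1. - a).sqrt() * eps        # q_sample closed form
x0_rec = (1. / a).sqrt() * x_t - (1. / a - 1.).sqrt() * eps
assert torch.allclose(x0_rec, x0, atol=1e-8)       # recovers x_0 exactly
```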
''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t
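Note how `p_sample_loop` makes generation conditional without any special network machinery: the clean partial points occupy the first `sv_points` positions of the last axis, only the tail is ever initialized as noise, and `p_sample` re-attaches the partial slice after every reverse step. A self-contained toy of that masking mechanic; `fake_step` is a stand-in for one reverse step, not the repo's model:

```python
import torch

B, C, sv, gen = 2, 3, 1024, 1024
partial = torch.randn(B, C, sv)          # observed points, held fixed throughout

def fake_step(data):
    # stand-in for one p_sample call: only the free tail is updated
    tail = 0.99 * data[:, :, sv:]
    return torch.cat([data[:, :, :sv], tail], dim=-1)

img_t = torch.cat([partial, torch.randn(B, C, gen)], dim=-1)
for _ in range(1000):
    img_t = fake_step(img_t)

# the conditioning slice survives the whole reverse chain untouched
assert torch.equal(img_t[:, :, :sv], partial)
```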
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate point clouds + Useful for visualizing how denoised samples evolve over time + Args: + freq (int): snapshot interval; the running sample is saved every freq + timesteps and at the final step + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0, total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start; this parameterization seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + + +############################################################################# +def get_dataset(data_root, npoints, category): + + train_ds = GANdatasetPartNet('test', data_root, category, npoints) + return train_ds + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['real'] + x_all = data['raw'] + + for j in range(5): + x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1) + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + 
clip_denoised=False).detach().cpu() + + + + + for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))): + + partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1) + rec = d[1] + rid = d[2] + write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j), + (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair') + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j), + (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy()) + + raw_id = rid.split('.')[0] + save_sample_dir = os.path.join(save_dir, "{}".format(raw_id)) + Path(save_sample_dir).mkdir(parents=True, exist_ok=True) + # save input partial shape + if j == 0: + save_path = os.path.join(save_sample_dir, "raw.ply") + write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path) + # save completed shape + save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j)) + write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path) + + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + + opt.netE = ckpt + logger.info("Resume Path:%s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + + if opt.generate_multimodal: + + generate_multimodal(opt, netE, outf_syn, logger) + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet', + help='input batch size') + + parser.add_argument('--classes', default='Table') + + parser.add_argument('--batch_size', type=int, default=64, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--eval_recon_mvr', default=False) + parser.add_argument('--generate_multimodal', default=True) + parser.add_argument('--eval_saved', default=False) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=1024) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + # constrain function + 
parser.add_argument('--constrain_eps', default=0.2) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + + main(opt) diff --git a/shape_completion/test_plane.py b/shape_completion/test_plane.py new file mode 100644 index 0000000..5981c3e --- /dev/null +++ b/shape_completion/test_plane.py @@ -0,0 +1,681 @@ + +from pprint import pprint +from metrics.evaluation_metrics import compute_all_metrics, EMD_CD + +import torch.nn as nn +import torch.utils.data + +import argparse +from torch.distributions import Normal + +from utils.file_utils import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_completion import PVCNN2Base + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. 
- betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.model_mean_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + )
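The 'fixedsmall'/'fixedlarge' choice above swaps between two precomputed schedules: the true posterior variance beta_tilde_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t), and beta_t itself. beta_tilde_t is strictly smaller for t > 0 and exactly zero at t = 0, which is why the log is clipped in `__init__` and why 'fixedlarge' substitutes the t = 1 posterior variance at index 0 of its log-variance table. A quick numeric check, under an illustrative linear schedule:

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1. - betas)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

# same formula as the posterior_variance computed in __init__
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

assert posterior_variance[0] == 0.0                 # degenerate first step
assert np.all(posterior_variance[1:] < betas[1:])   # 'fixedsmall' < 'fixedlarge'
```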
''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t
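`q_sample` above relies on the standard collapse of many single-step Gaussian transitions into one jump: applying x_s = sqrt(1 - beta_s) * x_{s-1} + sqrt(beta_s) * eps_s for s = 0..t is distributionally the same as sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps. A small Monte Carlo check of the variances under an illustrative schedule; for a standard-normal x_0 both should come out near 1:

```python
import numpy as np

rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1. - betas)

t, n = 500, 20000
x0 = rng.standard_normal(n)

# one-shot closed form, as in q_sample
xt_direct = np.sqrt(alphas_cumprod[t]) * x0 \
    + np.sqrt(1. - alphas_cumprod[t]) * rng.standard_normal(n)

# t + 1 sequential single diffusion steps
xt_seq = x0.copy()
for s in range(t + 1):
    xt_seq = np.sqrt(1. - betas[s]) * xt_seq + np.sqrt(betas[s]) * rng.standard_normal(n)

print(xt_direct.var(), xt_seq.var())  # both close to 1.0 up to Monte Carlo error
```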
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate point clouds + Useful for visualizing how denoised samples evolve over time + Args: + freq (int): snapshot interval; the running sample is saved every freq + timesteps and at the final step + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0, total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start; this parameterization seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + + +############################################################################# +def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + radius=3, elev=0, azim=0, img_size=512, focal_length=1000, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + +def get_mvr_dataset(pc_dataroot, views_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + 
tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def evaluate_recon_mvr(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + img_all = data['image'] + + + B,V,N,C = x_all.shape + gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + + x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + img = img_all.reshape(B * V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + images.append(img) + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + img_all = data['image'] + + # per-batch output directories, indexed by the loop counter i + Path(os.path.join(save_dir, 'x_%03d'%i)).mkdir(exist_ok=True, parents=True) + Path(os.path.join(save_dir, 'x_ply_%03d' % i)).mkdir(exist_ok=True, 
parents=True) + + for v in range(5): + x = x_all.transpose(1, 2).contiguous() + img = img_all + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): + + im = np.fliplr(np.flipud(d[-1])) + plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') + write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + + opt.netE = ckpt + logger.info("Resume Path:%s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + if opt.eval_recon_mvr: + # Evaluate generation + evaluate_recon_mvr(opt, netE, outf_syn, logger) + + if opt.generate_multimodal: + + generate_multimodal(opt, netE, outf_syn, logger) + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + help='input batch size') + parser.add_argument('--classes', default=['airplane']) + + parser.add_argument('--batch_size', type=int, default=8, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--eval_recon_mvr', default=True) + parser.add_argument('--generate_multimodal', default=False) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=200) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + # constrain function + 
parser.add_argument('--constrain_eps', default=0.2) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/airplane_ckpt/', help="path to netE (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + set_seed(opt) + + main(opt) diff --git a/shape_completion/test_table.py b/shape_completion/test_table.py new file mode 100644 index 0000000..8fb92be --- /dev/null +++ b/shape_completion/test_table.py @@ -0,0 +1,764 @@ + +from pprint import pprint +from metrics.evaluation_metrics import compute_all_metrics, EMD_CD + +import torch.nn as nn +import torch.utils.data + +import argparse +from torch.distributions import Normal +from utils.visualize import * +from utils.file_utils import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_completion import PVCNN2Base + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. 
- betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.model_mean_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + )
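Everything in this class funnels through `_extract`, which gathers one scalar per batch element from a length-T schedule and reshapes it to [B, 1, 1] so it broadcasts against the B x C x N point tensors. The same gather-and-reshape logic in isolation, as a standalone re-statement:

```python
import torch

def extract(a, t, x_shape):
    # same logic as GaussianDiffusion._extract
    bs, = t.shape
    out = torch.gather(a, 0, t)
    return out.reshape([bs] + (len(x_shape) - 1) * [1])

schedule = torch.linspace(0.1, 1.0, steps=10)   # length-T coefficient table
x = torch.randn(4, 3, 2048)                     # B x C x N point batch
t = torch.tensor([0, 3, 5, 9])                  # per-sample timesteps

coef = extract(schedule, t, x.shape)            # -> shape [4, 1, 1]
print(coef.shape, (coef * x).shape)             # broadcasts to [4, 3, 2048]
```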
''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t
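One detail of `p_sample` worth calling out is `nonzero_mask`: the ancestral update x_{t-1} = mu_theta + sigma_t * z is applied while t > 0, but at t = 0 the mask zeroes the noise per sample, so each chain ends deterministically on its predicted mean. The masking in isolation, with illustrative shapes:

```python
import torch

B, C, N = 4, 3, 2048
model_mean = torch.randn(B, C, N)
model_log_variance = torch.full((B, C, N), -4.0)
noise = torch.randn(B, C, N)

t = torch.tensor([0, 1, 7, 3])  # per-sample timesteps within one batch
nonzero_mask = torch.reshape(1 - (t == 0).float(), [B, 1, 1])

sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
assert torch.equal(sample[0], model_mean[0])      # t == 0: deterministic final step
assert not torch.equal(sample[1], model_mean[1])  # t > 0: stochastic step
```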
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate point clouds + Useful for visualizing how denoised samples evolve over time + Args: + freq (int): snapshot interval; the running sample is saved every freq + timesteps and at the final step + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0, total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start; this parameterization seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + + +############################################################################# +def get_pc_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + return tr_dataset + +def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + radius=3, elev=-89, azim=180, 
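On the schedules in `get_betas` above: `linear` is the usual linear beta ramp, while the `warm*` variants hold beta at `b_end` everywhere and only ramp linearly over the first fraction of timesteps. A usage sketch, assuming `get_betas` from this file is in scope:

```python
import numpy as np

betas = get_betas('linear', 1e-4, 0.02, 1000)        # linear ramp over all 1000 steps
warm = get_betas('warm0.1', 1e-4, 0.02, 1000)        # ramp over first 100 steps, flat after
assert warm[100:].max() == warm[100:].min() == 0.02
alphas_cumprod = np.cumprod(1.0 - betas)             # drives q(x_t | x_0) in the diffusion
```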
img_size=512, focal_length=1000, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + +def get_mvr_dataset(pc_dataroot, views_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def evaluate_recon_mvr(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + # img_all = data['image'] + + + B,V,N,C = x_all.shape + gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + + x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + # img = img_all.reshape(B * V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + # img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + # images.append(img) + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + # images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + del ref_pcs, masked, results + +def evaluate_saved(opt, netE, save_dir, logger): + ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn' + + gt_pth = ours_base + '/recon_gt.pth' + ours_pth = ours_base + '/ours_results.pth' + gt = torch.load(gt_pth).permute(1,0,2,3) + ours = torch.load(ours_pth)['pc'].permute(1,0,2,3) + + all_res = 
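`evaluate_recon_mvr` above folds the view axis into the batch, completes each partial cloud from its first `opt.svpoints` points, and undoes the dataset normalization with the per-set mean/std before computing metrics. A toy round-trip of those reshapes (illustrative shapes; `m`/`s` stand in for `data['mean']`/`data['std']`):

```python
import torch

B, V, N, C = 2, 3, 2048, 3
x_all = torch.randn(B, V, N, C)                      # normalized multiview partial clouds
m = torch.zeros(B, 1, 1, C)                          # stand-in for data['mean']
s = torch.ones(B, 1, 1, C)                           # stand-in for data['std']

x = x_all.reshape(B * V, N, C).transpose(1, 2)       # [B*V, C, N]: channel-first, views in batch
# ... the model completes x[:, :, svpoints:] here ...
x_world = x.transpose(1, 2).reshape(B, V, N, C) * s + m   # back to world scale
assert torch.allclose(x_world, x_all)                # identity with zero-mean/unit-std stand-ins
```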
{} + for i, (gt_, ours_) in enumerate(zip(gt, ours)): + results = compute_all_metrics(gt_, ours_, opt.batch_size) + + for key, val in results.items(): + if i == 0: + all_res[key] = val + else: + all_res[key] += val + pprint(results) + for key, val in all_res.items(): + all_res[key] = val / gt.shape[0] + + pprint({key: val.mean().item() for key, val in all_res.items()}) + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + img_all = data['image'] + + for v in range(6): + x = x_all.transpose(1, 2).contiguous() + img = img_all + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): + + im = np.fliplr(np.flipud(d[-1])) + plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') + write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair') + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + +def redwood_demo(opt, netE, save_dir, logger): + import open3d as o3d + pth = "/viscam/u/alexzhou907/01DATA/redwood/01605_sample_1.ply" + pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/01605_pc_gt.ply" + + points = np.asarray(o3d.io.read_point_cloud(pth).points) + + gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points) + + np.save('gt.npy', gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)]) + + write_to_xml_batch(os.path.join(save_dir, 'xml_gt'), + gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)][None], cat='table') + + test_dataset = get_pc_dataset(opt.dataroot_pc, opt.dataroot_sv, + opt.npoints, opt.classes) + + m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float() + + x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float() + + x = (x-m)/s + + + x = x[None].transpose(1,2).cuda() + + res = [] + for k in range(20): + recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda', + clip_denoised=False).detach().cpu() + recon = recon.transpose(1, 2).contiguous() + recon = recon * s+ m + res.append(recon) + res = torch.cat(res, dim=0) + + write_to_xml_batch(os.path.join(save_dir, 'xml'), + (res).numpy(), cat='table') + + export_to_pc_batch(os.path.join(save_dir, 'ply'), + (res).numpy()) + + torch.save(res, os.path.join(save_dir, 'redwood_demo.pth')) + + exit() + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = 
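Because the reverse chain is stochastic, `generate_multimodal` above draws several distinct completions of the same partial observation simply by calling `gen_samples` repeatedly. A condensed sketch, assuming `netE`, `x`, and `opt` are set up as in that function:

```python
modes = []
for v in range(6):                                   # 6 modes, as in generate_multimodal
    recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(),
                             x[:, :, opt.svpoints:].shape, 'cuda',
                             clip_denoised=False).detach().cpu()
    modes.append(recon.transpose(1, 2).contiguous()) # [B, N, 3], partial block included
```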
setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + + opt.netE = ckpt + logger.info("Resume Path:%s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + if opt.eval_recon_mvr: + # Evaluate generation + evaluate_recon_mvr(opt, netE, outf_syn, logger) + + if opt.generate_multimodal: + + generate_multimodal(opt, netE, outf_syn, logger) + + if opt.eval_saved: + evaluate_saved(opt, netE, outf_syn, logger) + + if opt.eval_redwood: + redwood_demo(opt, netE, outf_syn, logger) + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + help='input batch size') + parser.add_argument('--classes', default=['table']) + + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--eval_recon_mvr', default=False) + parser.add_argument('--generate_multimodal', default=False) + parser.add_argument('--eval_saved', default=False) + parser.add_argument('--eval_redwood', default=True) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=200) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + # constrain function + parser.add_argument('--constrain_eps', default=0.2) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/9_res32_pc_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-12-16-14-09-50', help="path to netE (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = 
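The eval driver above walks checkpoints newest-first, parsing the epoch out of names like `epoch_1999.pth`. Note that `str.strip('.pth')` strips a *character set* rather than a suffix; it happens to work for these names, but a more defensive key looks like this (hypothetical helper name):

```python
import os

def epoch_of(path):
    # 'epoch_1999.pth' -> 1999 via splitext, avoiding str.strip('.pth'),
    # which would also eat trailing 'p'/'t'/'h' characters from the stem
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.split('_')[-1])

# ckpts = sorted(ckpts, key=epoch_of, reverse=True)  # newest checkpoint first
```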
parse_args() + + main(opt) diff --git a/shape_completion/train_completion.py b/shape_completion/train_completion.py new file mode 100644 index 0000000..6c3aae5 --- /dev/null +++ b/shape_completion/train_completion.py @@ -0,0 +1,841 @@ +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from model.pvcnn_completion import PVCNN2Base +import torch.distributed as dist +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import ShapeNet_Multiview_Points +''' +some utils +''' +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. + """ + axis = np.asarray(axis) + axis = axis / np.sqrt(np.dot(axis, axis)) + a = np.cos(theta / 2.0) + b, c, d = -axis * np.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + +def rotate(vertices, faces): + ''' + vertices: [numpoints, 3] + ''' + M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() + N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() + K = rotation_matrix([0, 0, 1], np.pi).transpose() + + v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] + return v, f + +def norm(v, f): + v = (v - v.min())/(v.max() - v.min()) - 0.5 + + return v, f + +def getGradNorm(net): + pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) + gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) + return pNorm, gradNorm + + +def weights_init(m): + """ + xavier initialization + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1 and m.weight is not None: + torch.nn.init.xavier_normal_(m.weight) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_() + m.bias.data.fill_(0) + +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. 
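`normal_kl` above is the closed-form KL between Gaussians parameterized by mean and log-variance. A quick check of that formula against `torch.distributions` with toy values:

```python
import torch
from torch.distributions import Normal, kl_divergence

m1, lv1 = torch.tensor(0.3), torch.tensor(-1.0)
m2, lv2 = torch.tensor(0.0), torch.tensor(0.0)

kl_closed = 0.5 * (-1.0 + lv2 - lv1 + torch.exp(lv1 - lv2)
                   + (m1 - m2) ** 2 * torch.exp(-lv2))
kl_ref = kl_divergence(Normal(m1, torch.exp(0.5 * lv1)),
                       Normal(m2, torch.exp(0.5 * lv2)))
assert torch.allclose(kl_closed, kl_ref)
```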
- cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + +class GaussianDiffusion: + def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.sv_points = sv_points + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
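The buffers precomputed in `__init__` above all derive from `abar_t = prod_{s<=t}(1 - beta_s)`, via the one-shot forward identity `q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)` that `q_sample` implements. A scalar sketch:

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
abar = np.cumprod(1.0 - betas)
t = 500
x0, eps = 0.7, float(np.random.randn())              # toy "point" and its noise
x_t = np.sqrt(abar[t]) * x0 + np.sqrt(1.0 - abar[t]) * eps   # one q_sample draw
```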
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t)[:,:,self.sv_points:] + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape + assert model_variance.shape == model_log_variance.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = 
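`_predict_xstart_from_eps` above inverts the `q_sample` identity: if `x_t = sqrt(abar)*x0 + sqrt(1-abar)*eps`, then `x0 = sqrt(1/abar)*x_t - sqrt(1/abar - 1)*eps`, which is exactly what the `sqrt_recip_alphas_cumprod` and `sqrt_recipm1_alphas_cumprod` buffers encode. A scalar check:

```python
import numpy as np

abar = 0.4
x0, eps = 1.3, -0.2
x_t = np.sqrt(abar) * x0 + np.sqrt(1 - abar) * eps
x0_rec = np.sqrt(1 / abar) * x_t - np.sqrt(1 / abar - 1) * eps
assert abs(x0 - x0_rec) < 1e-12
```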
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) + + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) + + sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, partial_x, denoise_fn, shape, device, + noise_fn=torch.randn, clip_denoised=True, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + + assert isinstance(shape, (tuple, list)) + img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) + for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False) + + assert img_t[:,:,self.sv_points:].shape == shape + return img_t + + def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, + noise_fn=torch.randn,clip_denoised=True, keep_running=False): + """ + Generate samples, returning intermediate images + Useful for visualizing how denoised images evolve over time + Args: + repeat_noise_steps (int): Number of denoising timesteps in which the same noise + is used across the batch. If >= 0, the initial noise is the same for all batch elemements. + """ + assert isinstance(shape, (tuple, list)) + + total_steps = self.num_timesteps if not keep_running else len(self.betas) + + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + imgs = [img_t] + for t in reversed(range(0,total_steps)): + + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False) + if t % freq == 0 or t == total_steps-1: + imgs.append(img_t) + + assert imgs[-1].shape == shape + return imgs + + '''losses''' + + def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + + kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) + kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) + + return (kl, pred_xstart) if return_pred_xstart else kl + + def p_losses(self, denoise_fn, data_start, t, noise=None): + """ + Training loss calculation + """ + B, D, N = data_start.shape + assert t.shape == torch.Size([B]) + + if noise is None: + noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + + data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + + if self.loss_type == 'mse': + # predict the noise instead of x_start. 
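The completion-style `p_sample` / `p_sample_loop` above condition by construction: the first `sv_points` columns hold the observed partial cloud and are re-attached after every denoising step, so only the remaining points are ever resampled. A self-contained toy of that mechanics (the real denoising step is elided):

```python
import torch

B, C, sv, gen, T = 2, 3, 200, 1848, 10               # illustrative sizes
partial_x = torch.randn(B, C, sv)                    # observed points (held fixed)
img_t = torch.cat([partial_x, torch.randn(B, C, gen)], dim=-1)
for t in reversed(range(T)):
    free = img_t[:, :, sv:]
    # ... one reverse-diffusion step on `free` would go here ...
    img_t = torch.cat([partial_x, free], dim=-1)     # observed block re-attached
assert img_t[:, :, :sv].equal(partial_x)             # conditioning never drifts
```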
seems to be weighted naturally like SNR + eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + + losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == 'kl': + losses = self._vb_terms_bpd( + denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, + return_pred_xstart=False) + else: + raise NotImplementedError(self.loss_type) + + assert losses.shape == torch.Size([B]) + return losses + + '''debug''' + + def _prior_bpd(self, x_start): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, + mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + assert kl_prior.shape == x_start.shape + return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) + + def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): + + with torch.no_grad(): + B, T = x_start.shape[0], self.num_timesteps + + vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + for t in reversed(range(T)): + + t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) + # Calculate VLB term at the current timestep + data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + new_vals_b, pred_xstart = self._vb_terms_bpd( + denoise_fn, data_start=x_start, data_t=data_t, t=t_b, + clip_denoised=clip_denoised, return_pred_xstart=True) + # MSE for progressive prediction loss + assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + # Insert the calculated term into the tensor of all terms + mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt + mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt + assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) + + prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b + assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ + total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) + + self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + +def get_dataset(dataroot_pc, dataroot_sv, npoints, svpoints, category): + tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot_pc, + categories=[category], split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + tr_dataset = ShapeNet_Multiview_Points(root_pc=dataroot_pc, root_views=dataroot_sv, + cache=os.path.join(dataroot_pc, '../cache'), split='train', + categories=[category], + npoints=npoints, sv_samples=svpoints, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return tr_dataset + + +def get_dataloader(opt, train_dataset, test_dataset=None): + + if opt.distribution_type == 'multi': + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=opt.world_size, + rank=opt.rank + ) + if test_dataset is not None: + 
test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset, + num_replicas=opt.world_size, + rank=opt.rank + ) + else: + test_sampler = None + else: + train_sampler = None + test_sampler = None + + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler, + shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True) + + if test_dataset is not None: + test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + else: + test_dataloader = None + + return train_dataloader, test_dataloader, train_sampler, test_sampler + + +def train(gpu, opt, output_dir, noises_init): + + set_seed(opt) + logger = setup_logging(output_dir) + if opt.distribution_type == 'multi': + should_diag = gpu==0 + else: + should_diag = True + if should_diag: + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + if opt.distribution_type == 'multi': + if opt.dist_url == "env://" and opt.rank == -1: + opt.rank = int(os.environ["RANK"]) + + base_rank = opt.rank * opt.ngpus_per_node + opt.rank = base_rank + gpu + dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, + world_size=opt.world_size, rank=opt.rank) + + opt.bs = int(opt.bs / opt.ngpus_per_node) + opt.workers = 0 + + opt.saveIter = int(opt.saveIter / opt.ngpus_per_node) + opt.diagIter = int(opt.diagIter / opt.ngpus_per_node) + opt.vizIter = int(opt.vizIter / opt.ngpus_per_node) + + + ''' data ''' + train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes) + dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) + + + ''' + create networks + ''' + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.distribution_type == 'multi': # Multiple processes, single GPU per process + def _transform_(m): + return nn.parallel.DistributedDataParallel( + m, device_ids=[gpu], output_device=gpu) + + torch.cuda.set_device(gpu) + model.cuda(gpu) + model.multi_gpu_wrapper(_transform_) + + + elif opt.distribution_type == 'single': + def _transform_(m): + return nn.parallel.DataParallel(m) + model = model.cuda() + model.multi_gpu_wrapper(_transform_) + + elif gpu is not None: + torch.cuda.set_device(gpu) + model = model.cuda(gpu) + else: + raise ValueError('distribution_type = multi | single | None') + + if should_diag: + logger.info(opt) + + optimizer= optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999)) + + lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.lr_gamma) + + if opt.model != '': + ckpt = torch.load(opt.model) + model.load_state_dict(ckpt['model_state']) + optimizer.load_state_dict(ckpt['optimizer_state']) + + if opt.model != '': + start_epoch = torch.load(opt.model)['epoch'] + 1 + else: + start_epoch = 0 + + def new_x_chain(x, num_chain): + return torch.randn(num_chain, *x.shape[1:], device=x.device) + + + + for epoch in range(start_epoch, opt.niter): + + if opt.distribution_type == 'multi': + train_sampler.set_epoch(epoch) + + lr_scheduler.step(epoch) + + for i, data in enumerate(dataloader): + randind = np.random.choice(20) #20 views + x = data['train_points'].transpose(1,2) + sv_x = data['sv_points'][:,randind].transpose(1,2) + + sv_x[:,:,opt.svpoints:] = x[:,:,opt.svpoints:] + noises_batch = 
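The training loop indexes a fixed, per-shape noise table (`noises_init`, drawn once in `main`) by each batch's dataset indices, so every shape keeps a consistent initial noise across epochs; `get_loss_iter` then redraws noise only where `t != 0`. A shape sketch with illustrative sizes:

```python
import torch

n_data, npoints, svpoints, nc = 100, 2048, 200, 3
noises_init = torch.randn(n_data, npoints - svpoints, nc)   # one fixed noise per shape
idx = torch.tensor([3, 7, 42])                              # data['idx'] for a batch
noises_batch = noises_init[idx].transpose(1, 2)             # [B, nc, npoints - svpoints]
assert noises_batch.shape == torch.Size([3, nc, npoints - svpoints])
```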
noises_init[data['idx']].transpose(1,2) + + ''' + train diffusion + ''' + + if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None): + sv_x = sv_x.cuda(gpu) + noises_batch = noises_batch.cuda(gpu) + elif opt.distribution_type == 'single': + sv_x = sv_x.cuda() + noises_batch = noises_batch.cuda() + + loss = model.get_loss_iter(sv_x, noises_batch).mean() + + optimizer.zero_grad() + loss.backward() + netpNorm, netgradNorm = getGradNorm(model) + if opt.grad_clip is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + + optimizer.step() + + + if i % opt.print_freq == 0 and should_diag: + + logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, ' + 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' + .format( + epoch, opt.niter, i, len(dataloader),loss.item(), + netpNorm, netgradNorm, + )) + + + if (epoch + 1) % opt.diagIter == 0 and should_diag: + + logger.info('Diagnosis:') + + x_range = [x.min().item(), x.max().item()] + kl_stats = model.all_kl(sv_x) + logger.info(' [{:>3d}/{:>3d}] ' + 'x_range: [{:>10.4f}, {:>10.4f}], ' + 'total_bpd_b: {:>10.4f}, ' + 'terms_bpd: {:>10.4f}, ' + 'prior_bpd_b: {:>10.4f} ' + 'mse_bt: {:>10.4f} ' + .format( + epoch, opt.niter, + *x_range, + kl_stats['total_bpd_b'].item(), + kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item() + )) + + + + if (epoch + 1) % opt.vizIter == 0 and should_diag: + logger.info('Generation: eval') + + model.eval() + m, s = train_dataset.all_points_mean.reshape(1, -1), train_dataset.all_points_std.reshape(1, -1) + + with torch.no_grad(): + + x_gen_eval = model.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu() + + + gen_stats = [x_gen_eval.mean(), x_gen_eval.std()] + gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()] + + logger.info(' [{:>3d}/{:>3d}] ' + 'eval_gen_range: [{:>10.4f}, {:>10.4f}] ' + 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ' + .format( + epoch, opt.niter, + *gen_eval_range, *gen_stats, + )) + + export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch), + (x_gen_eval.transpose(1, 2)*s+m).numpy()*3) + + export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch), + (sv_x.transpose(1, 2).detach().cpu()*s+m).numpy()*3) + + export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch), + (sv_x[:,:,:opt.svpoints].transpose(1, 2).detach().cpu()*s+m).numpy()*3) + + + model.train() + + + + + + + + if (epoch + 1) % opt.saveIter == 0: + + if should_diag: + + + save_dict = { + 'epoch': epoch, + 'model_state': model.state_dict(), + 'optimizer_state': optimizer.state_dict() + } + + torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch)) + + + if opt.distribution_type == 'multi': + dist.barrier() + map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu} + model.load_state_dict( + torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state']) + + dist.destroy_process_group() + +def main(): + opt = parse_args() + + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + + ''' workaround ''' + + train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes) + noises_init = torch.randn(len(train_dataset), opt.npoints-opt.svpoints, opt.nc) + + if opt.dist_url == "env://" and opt.world_size == -1: + opt.world_size = int(os.environ["WORLD_SIZE"]) + + if 
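The inner loop above follows a compute-loss, backward, inspect-norms, optionally-clip, step pattern; `getGradNorm` is read purely for logging. A self-contained miniature using the script's Adam settings (`lr=2e-4`, `beta1=0.5`):

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))

loss = (model(torch.randn(8, 3)) ** 2).mean()
optimizer.zero_grad()
loss.backward()
pnorm = torch.sqrt(sum((p ** 2).sum() for p in model.parameters()))       # as in getGradNorm
gnorm = torch.sqrt(sum((p.grad ** 2).sum() for p in model.parameters()))
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)          # only if opt.grad_clip is set
optimizer.step()
```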
opt.distribution_type == 'multi': + opt.ngpus_per_node = torch.cuda.device_count() + opt.world_size = opt.ngpus_per_node * opt.world_size + mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init)) + else: + train(opt.gpu, opt, output_dir, noises_init) + + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + help='input batch size') + parser.add_argument('--category', default='chair') + + parser.add_argument('--bs', type=int, default=48, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + parser.add_argument('--svpoints', default=200) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + parser.add_argument('--lr', type=float, default=2e-4, help='learning rate for E, default=0.0002') + parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') + parser.add_argument('--decay', type=float, default=0, help='weight decay for EBM') + parser.add_argument('--grad_clip', type=float, default=None, help='weight decay for EBM') + parser.add_argument('--lr_gamma', type=float, default=0.998, help='lr decay for EBM') + + parser.add_argument('--model', default='', help="path to model (to continue training)") + + + '''distributed''' + parser.add_argument('--world_size', default=1, type=int, + help='Number of distributed nodes.') + parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str, + help='url used to set up distributed training') + parser.add_argument('--dist_backend', default='nccl', type=str, + help='distributed backend') + parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None], + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + parser.add_argument('--rank', default=0, type=int, + help='node rank for distributed training') + parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use. 
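In the multi-process branch above, rank 0 saves the checkpoint, every rank waits at `dist.barrier()`, and then each process reloads the same file with a `map_location` that redirects the saver's `cuda:0` tensors onto its own device. A sketch of that remapping (hypothetical file name and device id):

```python
import torch

gpu = 1                                              # this process's local GPU (illustrative)
map_location = {'cuda:0': 'cuda:%d' % gpu}
# state = torch.load('epoch_99.pth', map_location=map_location)
# model.load_state_dict(state['model_state'])
```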
None means using all available GPUs.') + + '''eval''' + parser.add_argument('--saveIter', default=100, help='unit: epoch') + parser.add_argument('--diagIter', default=50, help='unit: epoch') + parser.add_argument('--vizIter', default=50, help='unit: epoch') + parser.add_argument('--print_freq', default=50, help='unit: iter') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + + opt = parser.parse_args() + + return opt + +if __name__ == '__main__': + main() diff --git a/shapenet/__init__.py b/shapenet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shapenet/test_car.py b/shapenet/test_car.py new file mode 100644 index 0000000..7997128 --- /dev/null +++ b/shapenet/test_car.py @@ -0,0 +1,905 @@ +import torch +from pprint import pprint +from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD +from metrics.evaluation_metrics import compute_all_metrics, EMD_CD + +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_generation import PVCNN2Base + +from tqdm import tqdm + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + +class GaussianDiffusion: + def __init__(self,betas, loss_type, model_mean_type, model_var_type): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. 
- betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t) + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) + + if clip_denoised: + x_recon = torch.clamp(x_recon, -.5, .5) + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape == data.shape + assert model_variance.shape == model_log_variance.shape == data.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + 
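On the `fixedsmall`/`fixedlarge` choice in `p_mean_variance` above: `fixedlarge` uses `beta_t` itself as the reverse variance, while `fixedsmall` uses the true posterior variance, which is never larger. A numeric check:

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
abar = np.cumprod(1.0 - betas)
abar_prev = np.append(1.0, abar[:-1])
var_small = betas * (1.0 - abar_prev) / (1.0 - abar)   # posterior variance ('fixedsmall')
var_large = betas                                      # 'fixedlarge'
assert (var_small <= var_large + 1e-12).all()
```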
noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) + assert noise.shape == data.shape + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1)) + + sample = model_mean + if use_var: + sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + assert sample.shape == pred_xstart.shape + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, denoise_fn, shape, device, + noise_fn=torch.randn, constrain_fn=lambda x, t:x, + clip_denoised=True, max_timestep=None, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + if max_timestep is None: + final_time = self.num_timesteps + else: + final_time = max_timestep + + assert isinstance(shape, (tuple, list)) + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + for t in reversed(range(0, final_time if not keep_running else len(self.betas))): + img_t = constrain_fn(img_t, t) + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False).detach() + + + assert img_t.shape == shape + return img_t + + def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): + + assert t >= 1 + + t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) + encoding = self.q_sample(x0, t_vec) + + img_t = encoding + + for k in reversed(range(0,t)): + img_t = constrain_fn(img_t, k) + t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False, use_var=True).detach() + + + return img_t + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + +class Model(nn.Module): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) + + self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == 
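`reconstruct` above diffuses `x0` forward `t` steps with a single `q_sample` call and then runs the learned reverse chain back down to 0, so a small `t` stays close to the input while a large `t` resamples most of it. The fraction of the original signal surviving to step `t` is `sqrt(abar_t)`:

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
abar = np.cumprod(1.0 - betas)
for t in (50, 500, 999):
    print(t, round(float(np.sqrt(abar[t])), 4))      # roughly 0.985, 0.279, 0.006
```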
torch.int64 + + out = self.model(data, t) + + assert out.shape == torch.Size([B, D, N]) + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x, + clip_denoised=False, max_timestep=None, + keep_running=False): + return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, + constrain_fn=constrain_fn, + clip_denoised=clip_denoised, max_timestep=max_timestep, + keep_running=keep_running) + + def reconstruct(self, x0, t, constrain_fn=lambda x, t:x): + + return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn) + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + +def get_constrain_function(ground_truth, mask, eps, num_steps=1): + ''' + + :param target_shape_constraint: target voxels + :return: constrained x + ''' + # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2)) + eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 )) + def constrain_fn(x, t): + eps_ = eps_all[t] if (t<1000) else 0 + for _ in range(num_steps): + x = x - eps_ * ((x - ground_truth) * mask) + + + return x + return constrain_fn + + +############################################################################# + +def get_dataset(dataroot, npoints,category,use_mask=False): + tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True, use_mask = use_mask) + te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=category, split='val', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + use_mask=use_mask + ) + return tr_dataset, te_dataset + +def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + 
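`get_constrain_function` above returns a per-step projection: a few gradient-descent moves of the sample toward the ground truth on the observed (`mask == 1`) entries, i.e. a gradient step on the masked squared error. Because `eps_all` is reversed, the step size grows from near 0 at `t = T-1` to `eps` at `t = 0`, so the observed points are enforced more strongly as sampling converges. A self-contained sketch of one application:

```python
import numpy as np
import torch

eps = 0.2
eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000) ** 2))  # ~0 at t=999, eps at t=0

x = torch.randn(4, 3, 2048)
gt = torch.randn(4, 3, 2048)
mask = (torch.rand(4, 3, 2048) < 0.1).float()        # observed entries
t = 100
eps_t = eps_all[t] if t < 1000 else 0.0
x = x - eps_t * (x - gt) * mask                      # one constraint step
```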
normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + radius=3, elev=0, azim=0, img_size=512, focal_length=1000, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + + + return te_dataset +def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def evaluate_gen(opt, ref_pcs, logger): + + if ref_pcs is None: + _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'): + x = data['test_points'] + m, s = data['mean'].float(), data['std'].float() + + ref.append(x*s + m) + + ref_pcs = torch.cat(ref, dim=0).contiguous() + + logger.info("Loading sample path: %s" + % (opt.eval_path)) + sample_pcs = torch.load(opt.eval_path).contiguous() + + logger.info("Generation sample size:%s reference size: %s" + % (sample_pcs.size(), ref_pcs.size())) + + + # Compute metrics + results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + results = {k: (v.cpu().detach().item() + if not isinstance(v, float) else v) for k, v in results.items()} + + pprint(results) + logger.info(results) + + jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy()) + pprint('JSD: {}'.format(jsd)) + logger.info('JSD: {}'.format(jsd)) + +def evaluate_recon(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + randind = i%24 + gt_all = data['test_points'][:,randind:randind+1] + x_all = 
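A note on the statistics handling in these dataset getters: the val datasets are built with the train split's all_points_mean / all_points_std, so both splits share one normalization frame, and every evaluation maps clouds back to world coordinates via x * s + m before computing metrics. The pattern in isolation (shapes assumed for illustration):

```python
import torch

x = torch.randn(50, 2048, 3)   # normalized clouds (B, N, C)
m = torch.zeros(50, 1, 3)      # batch 'mean' entry, as in the loops below
s = torch.ones(50, 1, 1)       # batch 'std' entry
ref = x * s + m                # CD/EMD/JSD are computed in this frame
assert ref.shape == x.shape
```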
data['sv_points'][:,randind:randind+1] + mask_all= data['masks'][:,randind:randind+1] + img_all = data['image'][:,randind:randind+1] + + + B,V,N,C = x_all.shape + + x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous() + mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous() + img = img_all.reshape(B*V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) + # for t in [10]: + # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + # opt.constrain_steps)).detach().cpu() + + recon = netE.gen_samples(x.shape, 'cuda', + constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + opt.constrain_steps), + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + + # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): + # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) + # + # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + # + # k+=1 + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + images.append(img) + + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean() for key, val in results.items()}) + logger.info({key: val.mean() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + # + # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + # + # results = {k: (v.cpu().detach().item() + # if not isinstance(v, float) else v) for k, v in results.items()} + # pprint(results) + # logger.info(results) +def evaluate_recon_mvr(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + 
mask_all= data['masks'] + img_all = data['image'] + + + B,V,N,C = x_all.shape + gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + + # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) + # for t in [10]: + # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + # opt.constrain_steps)).detach().cpu() + + cd_res = [] + recon_res = [] + for p in range(5): + x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + img = img_all.reshape(B * V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x.shape, 'cuda', + constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + opt.constrain_steps), + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + cd = (((recon - x)**2)*mask.transpose(1, 2).contiguous()).sum(dim=(1,2)) + + cd_res.append(cd) + recon_res.append(recon) + + cd_res = torch.stack(cd_res, dim=0) + recon_res = torch.stack(recon_res, dim=0) + _, argmin = torch.min(cd_res, 0) + recon = recon_res[argmin,torch.arange(0,argmin.shape[0])] + + # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): + # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) + # + # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + # + # k+=1 + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + images.append(img) + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + # + # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + # + # results = {k: (v.cpu().detach().item() + # if not isinstance(v, float) else v) for k, v in results.items()} + # pprint(results) + # logger.info(results) + + + +def generate(netE, opt, logger): + + _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes) + + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + with torch.no_grad(): + + samples = [] + ref = [] + + for 
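The multi-view loop above draws five completions per batch, scores each against the observed points with a masked squared error, and keeps the per-element winner. The torch.min plus paired-index gather, in isolation with toy shapes:

```python
import torch

K, B = 5, 4
cd_res = torch.rand(K, B)                   # masked CD per restart and sample
recon_res = torch.randn(K, B, 2048, 3)      # candidate completions
_, argmin = torch.min(cd_res, dim=0)        # (B,) index of the best restart
best = recon_res[argmin, torch.arange(B)]   # paired indexing -> (B, 2048, 3)
assert best.shape == (B, 2048, 3)
```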
i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'): + + x = data['test_points'].transpose(1,2) + m, s = data['mean'].float(), data['std'].float() + + gen = netE.gen_samples(x.shape, + 'cuda', clip_denoised=False).detach().cpu() + + gen = gen.transpose(1,2).contiguous() + x = x.transpose(1,2).contiguous() + + + + gen = gen * s + m + x = x * s + m + samples.append(gen) + ref.append(x) + + visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None, + None, None) + + write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), default_color='b') + + samples = torch.cat(samples, dim=0) + ref = torch.cat(ref, dim=0) + + torch.save(samples, opt.eval_path) + + + + return ref + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_svr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + mask_all= data['masks'] + img_all = data['image'] + + + # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) + # for t in [10]: + # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + # opt.constrain_steps)).detach().cpu() + + Path(os.path.join(save_dir, 'x_%03d'%k)).mkdir(exist_ok=True, parents=True) + Path(os.path.join(save_dir, 'x_ply_%03d' % k)).mkdir(exist_ok=True, parents=True) + + for v in range(5): + x = x_all.transpose(1, 2).contiguous() + mask = mask_all.transpose(1, 2).contiguous() + img = img_all + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x.shape, 'cuda', + constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + opt.constrain_steps), + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + + for d in zip(list(gt_all), list(recon), list(x), list(img)): + + im = np.fliplr(np.flipud(d[-1])) + plt.imsave(os.path.join(save_dir, 'depth_%03d.png'%k), im, cmap='gray') + write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k, 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k, 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + k+=1 + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) 
if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + # e.g. '/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-02-45/epoch_1399.pth' + # e.g. '/viscam/u/alexzhou907/research/diffusion/shapenet/output/67_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-03-38/epoch_2799.pth' + opt.netE = ckpt + logger.info("Resume Path: %s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + ref = None + if opt.generate: + epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1]) + opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch)) + Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True) + ref = generate(netE, opt, logger) + if opt.eval_gen: + # Evaluate generation + evaluate_gen(opt, ref, logger) + + if opt.eval_recon: + # Evaluate reconstruction + evaluate_recon(opt, netE, outf_syn, logger) + + if opt.eval_recon_mvr: + # Evaluate multi-view reconstruction + evaluate_recon_mvr(opt, netE, outf_syn, logger) + + if opt.generate_multimodal: + + generate_multimodal(opt, netE, outf_syn, logger) + + exit() + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to ShapeNet point cloud dataset root') + parser.add_argument('--classes', default=['car']) + + parser.add_argument('--batch_size', type=int, default=50, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='number of data loading workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--generate', default=True) + parser.add_argument('--eval_gen', default=True) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + # constrain function + parser.add_argument('--constrain_eps', default=0.2) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', required=True, help="directory containing the checkpoints to evaluate") + + '''eval''' + + parser.add_argument('--eval_path', + default='') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + set_seed(opt) + + main(opt) + + # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair \ No newline at end of file diff --git a/shapenet/test_chair.py b/shapenet/test_chair.py new file mode
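Both test scripts order checkpoints with int(x.strip('.pth').split('_')[-1]). Note that str.strip removes any of the characters '.', 'p', 't', 'h' from the ends, not the literal suffix; it happens to work for names like epoch_2799.pth but is fragile. A hypothetical, stricter parser (epoch_of is an illustrative helper, not repo code):

```python
import os, re

def epoch_of(path):
    # Match the literal 'epoch_<number>.pth' suffix instead of stripping characters.
    m = re.search(r'epoch_(\d+)\.pth$', os.path.basename(path))
    return int(m.group(1)) if m else -1

ckpts = ['epoch_1399.pth', 'epoch_2799.pth', 'epoch_99.pth']
assert sorted(ckpts, key=epoch_of) == ['epoch_99.pth', 'epoch_1399.pth', 'epoch_2799.pth']
```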
100644 index 0000000..c88ec0b --- /dev/null +++ b/shapenet/test_chair.py @@ -0,0 +1,911 @@ +import torch +from pprint import pprint +from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD +from metrics.evaluation_metrics import compute_all_metrics, EMD_CD + +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_generation import PVCNN2Base + +from tqdm import tqdm + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + +class GaussianDiffusion: + def __init__(self,betas, loss_type, model_mean_type, model_var_type): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. 
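For reference, the buffers filled here are the coefficients of the standard DDPM posterior q(x_{t-1} | x_t, x_0); writing the cumulative product as ᾱ_t, they read:

```latex
\tilde{\mu}_t(x_t, x_0)
  = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\, x_0
  + \frac{\sqrt{\alpha_t}\,\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}\, x_t ,
\qquad
\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t
```

posterior_mean_coef1 and posterior_mean_coef2 are the two fractions above, posterior_variance is \tilde{\beta}_t, and its log is clipped because \tilde{\beta}_t is exactly 0 at the first step of the chain.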
- alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t) + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) + + if clip_denoised: + x_recon 
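q_sample above exploits the Gaussian closed form to jump straight to step t instead of simulating t forward steps. A self-contained version under the default linear schedule used by these scripts:

```python
import numpy as np
import torch

betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.from_numpy(np.cumprod(1.0 - betas)).float()

def q_sample(x0, t, noise):
    # x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise
    a = alphas_cumprod[t].reshape(-1, *([1] * (x0.dim() - 1)))
    return a.sqrt() * x0 + (1.0 - a).sqrt() * noise

x0 = torch.randn(2, 3, 2048)
x_t = q_sample(x0, torch.tensor([10, 999]), torch.randn_like(x0))
assert x_t.shape == x0.shape
```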
= torch.clamp(x_recon, -.5, .5) + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape == data.shape + assert model_variance.shape == model_log_variance.shape == data.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) + assert noise.shape == data.shape + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1)) + + sample = model_mean + if use_var: + sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + assert sample.shape == pred_xstart.shape + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, denoise_fn, shape, device, + noise_fn=torch.randn, constrain_fn=lambda x, t:x, + clip_denoised=True, max_timestep=None, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + if max_timestep is None: + final_time = self.num_timesteps + else: + final_time = max_timestep + + assert isinstance(shape, (tuple, list)) + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + for t in reversed(range(0, final_time if not keep_running else len(self.betas))): + img_t = constrain_fn(img_t, t) + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False).detach() + + + assert img_t.shape == shape + return img_t + + def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): + + assert t >= 1 + + t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) + encoding = self.q_sample(x0, t_vec) + + img_t = encoding + + for k in reversed(range(0,t)): + img_t = constrain_fn(img_t, k) + t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False, use_var=True).detach() + + + return img_t + + def interpolate(self, x0, x1, t, lamb, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): + + assert t >= 1 + + t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) + encoding0 = self.q_sample(x0, t_vec) + encoding1 = self.q_sample(x1, t_vec) + + enc = encoding0 * lamb + (1-lamb) * encoding1 + + img_t = enc + + for k in reversed(range(0,t)): + img_t = constrain_fn(img_t, k) + t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, 
noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False, use_var=True).detach() + + + return img_t + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + +class Model(nn.Module): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) + + self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + assert out.shape == torch.Size([B, D, N]) + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x, + clip_denoised=False, max_timestep=None, + keep_running=False): + return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, + constrain_fn=constrain_fn, + clip_denoised=clip_denoised, max_timestep=max_timestep, + keep_running=keep_running) + + def reconstruct(self, x0, t, constrain_fn=lambda x, t:x): + + return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn) + + def interpolate(self, x0, x1, t, lamb, constrain_fn=lambda x, t:x): + + return self.diffusion.interpolate(x0, x1, t, lamb, self._denoise, constrain_fn=constrain_fn) + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = 
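The sa_blocks / fp_blocks tuples configure PVCNN2Base, which is imported from model/pvcnn_generation.py and not shown in this diff. Following the conventions of the PVCNN codebase (an inference, so treat the field names as assumptions), each set-abstraction entry pairs a point-voxel convolution spec with a PointNet++ SA spec:

```python
# Hedged reading of the block configs, based on PVCNN conventions:
sa_blocks = [
    # ((conv_out_channels, n_conv_blocks, voxel_resolution),
    #  (n_centers, ball_radius, n_neighbors, sa_mlp_channels))
    ((32, 2, 32), (1024, 0.1, 32, (32, 64))),   # 2048 pts -> 1024 centers
    ((64, 3, 16), (256, 0.2, 32, (64, 128))),
    ((128, 3, 8), (64, 0.4, 32, (128, 256))),
    (None, (16, 0.8, 32, (256, 256, 512))),     # no voxel branch at the coarsest level
]
fp_blocks = [
    # (fp_mlp_channels, (conv_out_channels, n_conv_blocks, voxel_resolution))
    ((256, 256), (256, 3, 8)),
    ((256, 256), (256, 3, 8)),
    ((256, 128), (128, 2, 16)),
    ((128, 128, 64), (64, 2, 32)),
]
```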
int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + +def get_constrain_function(ground_truth, mask, eps, num_steps=1): + ''' + + :param target_shape_constraint: target voxels + :return: constrained x + ''' + # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2)) + eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 )) + def constrain_fn(x, t): + eps_ = eps_all[t] if (t<1000) else 0 + for _ in range(num_steps): + x = x - eps_ * ((x - ground_truth) * mask) + + + return x + return constrain_fn + + +############################################################################# + +def get_dataset(dataroot, npoints,category,use_mask=False): + tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True, use_mask = use_mask) + te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=category, split='val', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + use_mask=use_mask + ) + return tr_dataset, te_dataset + +def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + radius=3, elev=-89, azim=180, img_size=512, focal_length=1000, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + +def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + + + return te_dataset + +def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + 
all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def evaluate_gen(opt, ref_pcs, logger): + + if ref_pcs is None: + _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'): + x = data['test_points'] + m, s = data['mean'].float(), data['std'].float() + + ref.append(x*s + m) + + ref_pcs = torch.cat(ref, dim=0).contiguous() + + logger.info("Loading sample path: %s" + % (opt.eval_path)) + sample_pcs = torch.load(opt.eval_path).contiguous() + + logger.info("Generation sample size:%s reference size: %s" + % (sample_pcs.size(), ref_pcs.size())) + + + # Compute metrics + # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + # results = {k: (v.cpu().detach().item() + # if not isinstance(v, float) else v) for k, v in results.items()} + # + # pprint(results) + # logger.info(results) + + jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy()) + pprint('JSD: {}'.format(jsd)) + logger.info('JSD: {}'.format(jsd)) + +def evaluate_recon(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + randind = i%24 + gt_all = data['test_points'][:,randind:randind+1] + x_all = data['sv_points'][:,randind:randind+1] + mask_all= data['masks'][:,randind:randind+1] + img_all = data['image'][:,randind:randind+1] + + + B,V,N,C = x_all.shape + + gt = gt_all.reshape(B*V,N,C).transpose(1,2).contiguous() + x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous() + mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous() + img = img_all.reshape(B*V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) + # for t in [10]: + # t_vec = torch.empty(gt.shape[0], dtype=torch.int64, device='cuda').fill_(80) + # recon = netE.diffusion.q_sample(gt.cuda(), t_vec).detach().cpu() + recon = netE.gen_samples(x.shape, 'cuda', + constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + opt.constrain_steps), + clip_denoised=False).detach().cpu() + + # recon = recon.transpose(1, 2).contiguous() + # x = x.transpose(1, 2).contiguous() + # gt = gt.transpose(1, 2).contiguous() + # write_to_xml_batch(os.path.join(save_dir, 'intermediate_%03d' % i), + # (recon.detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy()) + # write_to_xml_batch(os.path.join(save_dir, 'x_%03d' % i), + # (gt.detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy()) + # write_to_xml_batch(os.path.join(save_dir, 'noise_%03d' % i), + # (torch.randn_like(gt).detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy()) + # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): + # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), 
(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) + # + # k+=1 + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + images.append(img) + + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + # + # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + # + # results = {k: (v.cpu().detach().item() + # if not isinstance(v, float) else v) for k, v in results.items()} + # pprint(results) + # logger.info(results) + +def evaluate_recon_mvr(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + mask_all= data['masks'] + # img_all = data['image'] + + + B,V,N,C = x_all.shape + gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + + x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + # img = img_all.reshape(B * V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x.shape, 'cuda', + constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + opt.constrain_steps), + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + # img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + # images.append(img) + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + # images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + # torch.save(images.reshape(B,V, 
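With reduced=False, EMD_CD is expected to return per-sample vectors of length B*V; the dict comprehension above folds those back into (B, V) so the logged means aggregate over shapes and views. In isolation (the 'MMD-CD' key is an assumption about the metric dict, for illustration only):

```python
import torch

B, V = 4, 2
results = {'MMD-CD': torch.rand(B * V)}      # hypothetical per-sample metric
results = {k: v.reshape(B, V) if v.shape == torch.Size([B * V]) else v
           for k, v in results.items()}
print(results['MMD-CD'].mean(dim=0))         # per-view means, shape (V,)
```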
*images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + # + # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + # + # results = {k: (v.cpu().detach().item() + # if not isinstance(v, float) else v) for k, v in results.items()} + # pprint(results) + # logger.info(results) + + del ref_pcs, masked + + +def generate(netE, opt, logger): + + _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes) + + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + with torch.no_grad(): + + samples = [] + ref = [] + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'): + + x = data['test_points'].transpose(1,2) + m, s = data['mean'].float(), data['std'].float() + + gen = netE.gen_samples(x.shape, + 'cuda', clip_denoised=False).detach().cpu() + + gen = gen.transpose(1,2).contiguous() + x = x.transpose(1,2).contiguous() + + + + gen = gen * s + m + x = x * s + m + samples.append(gen) + ref.append(x) + + visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None, + None, None) + + write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), cat='chair') + + samples = torch.cat(samples, dim=0) + ref = torch.cat(ref, dim=0) + + torch.save(samples, opt.eval_path) + + + + return ref + + + +def generate_multimodal(opt, netE, save_dir, logger): + test_dataset = get_svr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + mask_all= data['masks'] + img_all = data['image'] + + + # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) + # for t in [10]: + # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + # opt.constrain_steps)).detach().cpu() + + for v in range(10): + x = x_all.transpose(1, 2).contiguous() + mask = mask_all.transpose(1, 2).contiguous() + img = img_all + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x.shape, 'cuda', + constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + opt.constrain_steps), + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): + + im = 
np.fliplr(np.flipud(d[-1])) + plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') + write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + # e.g. '/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-02-45/epoch_1399.pth' + # e.g. '/viscam/u/alexzhou907/research/diffusion/shapenet/output/67_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-03-38/epoch_2799.pth' + # e.g. '/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do_best/2020-10-16-12-23-44/epoch_1799.pth' + opt.netE = ckpt + logger.info("Resume Path: %s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + ref = None + if opt.generate: + epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1]) + opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch)) + Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True) + ref = generate(netE, opt, logger) + if opt.eval_gen: + # Evaluate generation + evaluate_gen(opt, ref, logger) + + if opt.eval_recon: + # Evaluate reconstruction + evaluate_recon(opt, netE, outf_syn, logger) + + if opt.eval_recon_mvr: + # Evaluate multi-view reconstruction + evaluate_recon_mvr(opt, netE, outf_syn, logger) + + if opt.generate_multimodal: + + generate_multimodal(opt, netE, outf_syn, logger) + + exit() + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to ShapeNet point cloud dataset root') + parser.add_argument('--classes', default=['chair']) + + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='number of data loading workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--generate', default=True) + parser.add_argument('--eval_gen', default=False) + parser.add_argument('--eval_recon', default=False) + parser.add_argument('--eval_recon_mvr', default=False) + parser.add_argument('--generate_multimodal', default=False) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear')
+ parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + # constrain function + parser.add_argument('--constrain_eps', default=.051) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do_best/2020-10-16-12-23-44/', help="path to netE (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair/2020-10-18-13-46-21/syn/epoch_1699_samples.pth') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + set_seed(opt) + + main(opt) + + # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair/2020-10-18-13-46-21 \ No newline at end of file diff --git a/shapenet/test_generation.py b/shapenet/test_generation.py new file mode 100644 index 0000000..979e0d0 --- /dev/null +++ b/shapenet/test_generation.py @@ -0,0 +1,589 @@ +import torch +from pprint import pprint +from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD +from metrics.evaluation_metrics import compute_all_metrics, EMD_CD + +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_generation import PVCNN2Base + +from tqdm import tqdm + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds + +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. 
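normal_kl, duplicated at the top of each test script, is the closed-form KL divergence between Gaussians parameterized by mean and log-variance. A quick sanity check of that formula against torch.distributions:

```python
import torch
from torch.distributions import Normal, kl_divergence

def normal_kl(mean1, logvar1, mean2, logvar2):
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * torch.exp(-logvar2))

m1, lv1 = torch.tensor(0.3), torch.tensor(-1.0)
m2, lv2 = torch.tensor(-0.2), torch.tensor(0.5)
ref = kl_divergence(Normal(m1, (0.5 * lv1).exp()), Normal(m2, (0.5 * lv2).exp()))
assert torch.allclose(normal_kl(m1, lv1, m2, lv2), ref, atol=1e-6)
```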
- cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + +class GaussianDiffusion: + def __init__(self,betas, loss_type, model_mean_type, model_var_type): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t) + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) + + if clip_denoised: + x_recon = torch.clamp(x_recon, -.5, .5) + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape == data.shape + assert model_variance.shape == model_log_variance.shape == data.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + 
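_predict_xstart_from_eps inverts the one-shot forward form: x_0 = sqrt(1/ᾱ_t) x_t − sqrt(1/ᾱ_t − 1) ε. A round-trip check under the same linear schedule (illustrative, not repo code):

```python
import numpy as np
import torch

betas = np.linspace(1e-4, 0.02, 1000)
a_bar = torch.from_numpy(np.cumprod(1.0 - betas)).float()

x0 = torch.randn(2, 3, 128)
t = torch.tensor([50, 700])
eps = torch.randn_like(x0)
a = a_bar[t].reshape(-1, 1, 1)

x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps                # forward (q_sample)
x0_rec = (1 / a).sqrt() * x_t - (1 / a - 1).sqrt() * eps  # inversion
assert torch.allclose(x0_rec, x0, atol=1e-4)
```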
noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) + assert noise.shape == data.shape + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1)) + + sample = model_mean + if use_var: + sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + assert sample.shape == pred_xstart.shape + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, denoise_fn, shape, device, + noise_fn=torch.randn, constrain_fn=lambda x, t:x, + clip_denoised=True, max_timestep=None, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + if max_timestep is None: + final_time = self.num_timesteps + else: + final_time = max_timestep + + assert isinstance(shape, (tuple, list)) + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + for t in reversed(range(0, final_time if not keep_running else len(self.betas))): + img_t = constrain_fn(img_t, t) + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False).detach() + + + assert img_t.shape == shape + return img_t + + def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): + + assert t >= 1 + + t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) + encoding = self.q_sample(x0, t_vec) + + img_t = encoding + + for k in reversed(range(0,t)): + img_t = constrain_fn(img_t, k) + t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False, use_var=True).detach() + + + return img_t + + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + +class Model(nn.Module): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) + + self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == 
torch.int64 + + out = self.model(data, t) + + assert out.shape == torch.Size([B, D, N]) + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x, + clip_denoised=False, max_timestep=None, + keep_running=False): + return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, + constrain_fn=constrain_fn, + clip_denoised=clip_denoised, max_timestep=max_timestep, + keep_running=keep_running) + + def reconstruct(self, x0, t, constrain_fn=lambda x, t:x): + + return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn) + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + +def get_constrain_function(ground_truth, mask, eps, num_steps=1): + ''' + + :param target_shape_constraint: target voxels + :return: constrained x + ''' + # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2)) + eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 )) + def constrain_fn(x, t): + eps_ = eps_all[t] if (t<1000) else 0 + for _ in range(num_steps): + x = x - eps_ * ((x - ground_truth) * mask) + + + return x + return constrain_fn + + +############################################################################# + +def get_dataset(dataroot, npoints,category,use_mask=False): + tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=[category], split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True, use_mask = use_mask) + te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=[category], split='val', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + use_mask=use_mask + ) + return tr_dataset, te_dataset + + + +def evaluate_gen(opt, ref_pcs, logger): + + if ref_pcs is None: + _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category, use_mask=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, 
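+         # Usage sketch for the completion constraint defined above (hypothetical
+         # call site; `gt_partial` and `obs_mask` are made-up (B, 3, N) tensors):
+         #     cfn = get_constrain_function(gt_partial, obs_mask, eps=0.1)
+         #     completion = model.gen_samples(gt_partial.shape, 'cuda',
+         #                                    constrain_fn=cfn, clip_denoised=False)
+         # Each reverse step then nudges x toward the observed points on the mask,
+         # with a step size eps_ that is largest near t = 0 and fades for large t.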
num_workers=int(opt.workers), drop_last=False) + ref = [] + for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'): + x = data['test_points'] + m, s = data['mean'].float(), data['std'].float() + + ref.append(x*s + m) + + ref_pcs = torch.cat(ref, dim=0).contiguous() + + logger.info("Loading sample path: %s" + % (opt.eval_path)) + sample_pcs = torch.load(opt.eval_path).contiguous() + + logger.info("Generation sample size:%s reference size: %s" + % (sample_pcs.size(), ref_pcs.size())) + + + # Compute metrics + results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + results = {k: (v.cpu().detach().item() + if not isinstance(v, float) else v) for k, v in results.items()} + + pprint(results) + logger.info(results) + + jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy()) + pprint('JSD: {}'.format(jsd)) + logger.info('JSD: {}'.format(jsd)) + + + +def generate(model, opt): + + _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category) + + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + with torch.no_grad(): + + samples = [] + ref = [] + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'): + + x = data['test_points'].transpose(1,2) + m, s = data['mean'].float(), data['std'].float() + + gen = model.gen_samples(x.shape, + 'cuda', clip_denoised=False).detach().cpu() + + gen = gen.transpose(1,2).contiguous() + x = x.transpose(1,2).contiguous() + + + + gen = gen * s + m + x = x * s + m + samples.append(gen) + ref.append(x) + + visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None, + None, None) + + samples = torch.cat(samples, dim=0) + ref = torch.cat(ref, dim=0) + + torch.save(samples, opt.eval_path) + + + + return ref + + +def main(opt): + + if opt.category == 'airplane': + opt.beta_start = 1e-5 + opt.beta_end = 0.008 + opt.schedule_type = 'warm0.1' + + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if opt.cuda: + model.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + model = model.cuda() + model.multi_gpu_wrapper(_transform_) + + model.eval() + + with torch.no_grad(): + + logger.info("Resume Path:%s" % opt.model) + + resumed_param = torch.load(opt.model) + model.load_state_dict(resumed_param['model_state']) + + + ref = None + if opt.generate: + opt.eval_path = os.path.join(outf_syn, 'samples.pth') + Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True) + ref=generate(model, opt) + + if opt.eval_gen: + # Evaluate generation + evaluate_gen(opt, ref, logger) + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--category', default='car') + + parser.add_argument('--batch_size', type=int, default=50, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs 
to train for') + + parser.add_argument('--generate',default=True) + parser.add_argument('--eval_gen', default=True) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + '''model''' + parser.add_argument('--beta_start', default=0.0001) + parser.add_argument('--beta_end', default=0.02) + parser.add_argument('--schedule_type', default='linear') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + + parser.add_argument('--model', default='',required=True, help="path to model (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + default='') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + set_seed(opt) + + main(opt) + + # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair \ No newline at end of file diff --git a/shapenet/test_plane.py b/shapenet/test_plane.py new file mode 100644 index 0000000..5fd3b1d --- /dev/null +++ b/shapenet/test_plane.py @@ -0,0 +1,925 @@ +import torch +import functools +from pprint import pprint +from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD +from metrics.evaluation_metrics import compute_all_metrics, EMD_CD, distChamfer + +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from utils.mitsuba_renderer import write_to_xml_batch +from model.pvcnn_generation import PVCNN2Base + +from tqdm import tqdm + +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from datasets.shapenet_data_sv import * +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. 
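+     # the CDF differences here and below are floored at 1e-12 before the log so
+     # saturated bins cannot yield -inf; the same guard appears in the reference DDPM code.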
- cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + +class GaussianDiffusion: + def __init__(self,betas, loss_type, model_mean_type, model_var_type): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
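+         # q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I); _extract gathers the
+         # per-timestep scalar and reshapes it to [B, 1, 1] so it broadcasts over (B, D, N).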
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t) + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) + + if clip_denoised: + x_recon = torch.clamp(x_recon, -.5, .5) + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape == data.shape + assert model_variance.shape == model_log_variance.shape == data.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + 
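+         # model_mean above follows the eps-parameterization of p_mean_variance:
+         #     x0_hat   = sqrt(1/abar_t) * x_t - sqrt(1/abar_t - 1) * eps_theta(x_t, t)
+         #     mu_theta = coef1_t * x0_hat + coef2_t * x_t    # posterior q(x_{t-1} | x_t, x0_hat)
+         # with a fixed (not learned) variance chosen by model_var_type.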
noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) + assert noise.shape == data.shape + # no noise when t == 0 + nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1)) + + sample = model_mean + if use_var: + sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise + assert sample.shape == pred_xstart.shape + return (sample, pred_xstart) if return_pred_xstart else sample + + + def p_sample_loop(self, denoise_fn, shape, device, + noise_fn=torch.randn, constrain_fn=lambda x, t:x, + clip_denoised=True, max_timestep=None, keep_running=False): + """ + Generate samples + keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps + + """ + if max_timestep is None: + final_time = self.num_timesteps + else: + final_time = max_timestep + + assert isinstance(shape, (tuple, list)) + img_t = noise_fn(size=shape, dtype=torch.float, device=device) + for t in reversed(range(0, final_time if not keep_running else len(self.betas))): + img_t = constrain_fn(img_t, t) + t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, + clip_denoised=clip_denoised, return_pred_xstart=False).detach() + + + assert img_t.shape == shape + return img_t + + def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): + + assert t >= 1 + + t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) + encoding = self.q_sample(x0, t_vec) + + img_t = encoding + + for k in reversed(range(0,t)): + img_t = constrain_fn(img_t, k) + t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False, use_var=True).detach() + + + return img_t + + def reconstruct2(self, x0, mask, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda forward, x, t:x): + z = noise_fn(size=x0.shape, dtype=torch.float, device=x0.device) + + for _ in range(10): + img_t = z + outputs =[None for _ in range(len(self.betas))] + for t in reversed(range(0, len(self.betas))): + + t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t) + outputs[t] = img_t.detach().cpu().clone() + + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False, use_var=True).detach() + + + img_t = torch.autograd.Variable(img_t.data, requires_grad=True) + + dist = ((img_t - x0) ** 2 * mask).sum(dim=0).mean() + grad = torch.autograd.grad(dist, [img_t])[0].detach() + + print('Dist', dist.detach().cpu().item()) + + for t in (range(0, len(outputs))): + + t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t) + + x = outputs[t].to(x0).requires_grad_() + + y = self.p_sample(denoise_fn=denoise_fn, data=x, t=t_, noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False, use_var=True) + + grad = torch.autograd.grad(y, [x], grad_outputs=grad)[0] + + + z = x.detach().to(x0) - 0.1 * grad.detach() + + img_t = z + for t in reversed(range(0, len(self.betas))): + + t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t) + + img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, + clip_denoised=False, return_pred_xstart=False, use_var=True).detach() + + return img_t + + +# class PVCNN2(PVCNN2Base): +# sa_blocks = [ +# ((32, 2, 32), (1024, 0.1, 32, 
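+# NOTE (gloss, following the PVCNN conventions this class builds on; not in the
+# original source): each sa_blocks entry reads
+#   ((conv_out_channels, n_conv_blocks, voxel_resolution),
+#    (n_centers, radius, n_neighbors, sa_mlp_channels))
+# and each fp_blocks entry reads
+#   (fp_mlp_channels, (conv_out_channels, n_conv_blocks, voxel_resolution)).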
(32, 64))), +# ((64, 3, 16), (256, 0.2, 32, (64, 128))), +# ((128, 3, 8), (64, 0.4, 32, (128, 256))), +# (None, (16, 0.8, 32, (256, 256, 512))), +# ] +# fp_blocks = [ +# ((256, 256), (256, 3, 8)), +# ((256, 256), (256, 3, 8)), +# ((256, 128), (128, 2, 16)), +# ((128, 128, 64), (64, 2, 32)), +# ] +# +# def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, +# voxel_resolution_multiplier=1): +# super().__init__( +# num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, +# dropout=dropout, extra_feature_channels=extra_feature_channels, +# width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier +# ) + +class PVCNN2(PVCNN2Base): + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, + voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + +class Model(nn.Module): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + super(Model, self).__init__() + self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) + + self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + assert out.shape == torch.Size([B, D, N]) + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x, + clip_denoised=False, max_timestep=None, + keep_running=False): + return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, + constrain_fn=constrain_fn, + clip_denoised=clip_denoised, max_timestep=max_timestep, + keep_running=keep_running) + + def reconstruct(self, x0, t, constrain_fn=lambda x, t:x): + + return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn) + + def reconstruct2(self, x0, mask, constrain_fn): + + return self.diffusion.reconstruct2(x0, mask, self._denoise, constrain_fn=constrain_fn) + + def train(self): + self.model.train() 
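+
+     # Minimal optimization-step sketch (hypothetical helper, not in the original
+     # file): shows how get_loss_iter above would be driven during training;
+     # `x` is a (B, 3, N) batch and `optimizer` any torch.optim optimizer.
+     def _demo_train_step(self, x, optimizer):
+         loss = self.get_loss_iter(x).mean()   # per-shape losses, averaged over the batch
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         return loss.item()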
+ + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + +def get_constrain_function(ground_truth, mask, eps, num_steps=1): + ''' + + :param target_shape_constraint: target voxels + :return: constrained x + ''' + # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2)) + eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 500)**2 )) + def constrain_fn(x, t): + eps_ = eps_all[t] if (t<500) else 0 + for _ in range(num_steps): + x = x - eps_ * ((x - ground_truth) * mask) + + + return x + + # mask_single = mask[0, :, 0] + # num = mask_single.sum().int().item() + def constrain_fn2(forward, x, t): + + x = torch.autograd.Variable(x.data, requires_grad=True) + y = forward(x) + + + + dist = ((y - ground_truth)**2 * mask).sum(dim=0).mean() + grad = torch.autograd.grad(dist, [x], retain_graph=True)[0] + x = x - eps * (grad) + + print('Dist', dist.detach().cpu().item()) + + return x + return constrain_fn + + + +############################################################################# + +def get_dataset(dataroot, npoints,category,use_mask=False): + tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True, use_mask = use_mask) + te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=category, split='val', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + use_mask=use_mask + ) + return tr_dataset, te_dataset + +def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, + cache=os.path.join(mesh_root, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + +def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, + categories=category, split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = 
ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, + cache=os.path.join(pc_dataroot, '../cache'), split='val', + categories=category, + npoints=npoints, sv_samples=200, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return te_dataset + + +def evaluate_gen(opt, ref_pcs, logger): + + if ref_pcs is None: + _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'): + x = data['test_points'] + m, s = data['mean'].float(), data['std'].float() + + ref.append(x*s + m) + + ref_pcs = torch.cat(ref, dim=0).contiguous() + + logger.info("Loading sample path: %s" + % (opt.eval_path)) + sample_pcs = torch.load(opt.eval_path).contiguous() + + logger.info("Generation sample size:%s reference size: %s" + % (sample_pcs.size(), ref_pcs.size())) + + + # Compute metrics + # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + # results = {k: (v.cpu().detach().item() + # if not isinstance(v, float) else v) for k, v in results.items()} + # + # pprint(results) + # logger.info(results) + + jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy()) + pprint('JSD: {}'.format(jsd)) + logger.info('JSD: {}'.format(jsd)) + + +def evaluate_recon(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + randind = i%24 + gt_all = data['test_points'][:,randind:randind+1] + x_all = data['sv_points'][:,randind:randind+1] + mask_all= data['masks'][:,randind:randind+1] + img_all = data['image'][:,randind:randind+1] + + + B,V,N,C = x_all.shape + + x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous() + mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous() + img = img_all.reshape(B*V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) + # for t in [10]: + # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + # opt.constrain_steps)).detach().cpu() + + recon = netE.gen_samples(x.shape, 'cuda', + constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + opt.constrain_steps), + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + + # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): + # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) + # + # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] 
+ m[0:1]).numpy()) + # + # k+=1 + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + images.append(img) + + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean() for key, val in results.items()}) + logger.info({key: val.mean() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + # + # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + # + # results = {k: (v.cpu().detach().item() + # if not isinstance(v, float) else v) for k, v in results.items()} + # pprint(results) + # logger.info(results) + + + +def evaluate_recon_mvr(opt, netE, save_dir, logger): + test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', + opt.npoints, opt.classes) + # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + ref = [] + samples = [] + images = [] + masked = [] + k = 0 + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + + gt_all = data['test_points'] + x_all = data['sv_points'] + mask_all= data['masks'] + img_all = data['image'] + + + B,V,N,C = x_all.shape + gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + + # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) + # for t in [10]: + # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + # opt.constrain_steps)).detach().cpu() + + cd_res = [] + recon_res = [] + for p in range(5): + x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous() + img = img_all.reshape(B * V, *img_all.shape[2:]) + + m, s = data['mean'].float(), data['std'].float() + + recon = netE.gen_samples(x.shape, 'cuda', + constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, + opt.constrain_steps), + clip_denoised=False).detach().cpu() + + recon = recon.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() + + cd = (((recon - x)**2)*mask.transpose(1, 2).contiguous()).sum(dim=(1,2)) + + cd_res.append(cd) + recon_res.append(recon) + + cd_res = torch.stack(cd_res, dim=0) + recon_res = torch.stack(recon_res, dim=0) + _, argmin = torch.min(cd_res, 0) + recon = recon_res[argmin,torch.arange(0,argmin.shape[0])] + + # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): + # 
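+         # Gloss on the selection above (not in the original comments): five constrained
+         # completions are sampled per view and scored by masked squared error against
+         # the partial observation,
+         #     cd_p = sum_n mask_n * || recon_p,n - x_n ||^2,
+         # and torch.min picks, per shape, the candidate most consistent with the input.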
write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) + # + # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) + # + # k+=1 + + x_adj = x.reshape(B,V,N,C)* s + m + recon_adj = recon.reshape(B,V,N,C)* s + m + img = img.reshape(B,V,*img.shape[1:]) + + ref.append( gt_all * s + m) + masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + samples.append(recon_adj) + images.append(img) + + ref_pcs = torch.cat(ref, dim=0) + sample_pcs = torch.cat(samples, dim=0) + images = torch.cat(images, dim=0) + masked = torch.cat(masked, dim=0) + + B, V, N, C = ref_pcs.shape + + + torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) + torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) + torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + # Compute metrics + results = EMD_CD(sample_pcs.reshape(B*V, N, C), + ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + + results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + + pprint({key: val.mean().item() for key, val in results.items()}) + logger.info({key: val.mean().item() for key, val in results.items()}) + + results['pc'] = sample_pcs + torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + + # + # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) + # + # results = {k: (v.cpu().detach().item() + # if not isinstance(v, float) else v) for k, v in results.items()} + # pprint(results) + # logger.info(results) + + + +def generate(netE, opt, logger): + + _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes) + + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), drop_last=False) + + with torch.no_grad(): + + samples = [] + ref = [] + + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'): + + x = data['test_points'].transpose(1,2) + m, s = data['mean'].float(), data['std'].float() + + gen = netE.gen_samples(x.shape, + 'cuda', clip_denoised=False).detach().cpu() + + gen = gen.transpose(1,2).contiguous() + x = x.transpose(1,2).contiguous() + + + + gen = gen * s + m + x = x * s + m + samples.append(gen) + ref.append(x) + + # visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x_%03d.png'%i), gen[:64], None, + # None, None) + # export_to_pc_batch(os.path.join(str(Path(opt.eval_path).parent), 'ply_%03d'%i), + # gen[:64].numpy()) + write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), cat='airplane') + + samples = torch.cat(samples, dim=0) + ref = torch.cat(ref, dim=0) + + torch.save(samples, opt.eval_path) + + + + return ref + +def main(opt): + exp_id = os.path.splitext(os.path.basename(__file__))[0] + dir_id = os.path.dirname(__file__) + output_dir = get_output_dir(dir_id, exp_id) + copy_source(__file__, output_dir) + logger = setup_logging(output_dir) + + outf_syn, = setup_output_subdirs(output_dir, 'syn') + + betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) + netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) + + if 
opt.cuda: + netE.cuda() + + def _transform_(m): + return nn.parallel.DataParallel(m) + + netE = netE.cuda() + netE.multi_gpu_wrapper(_transform_) + + # netE.eval() + + ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] + + with torch.no_grad(): + for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): + #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-02-45/epoch_1399.pth' + #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/67_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-03-38/epoch_2799.pth' + # opt.netE = '/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do_best/2020-10-13-13-33-53/epoch_2899.pth'#ckpt + opt.netE = '/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do/2020-10-07-13-26-10/epoch_2299.pth' + logger.info("Resume Path:%s" % opt.netE) + + resumed_param = torch.load(opt.netE) + netE.load_state_dict(resumed_param['model_state']) + + + ref = None + if opt.generate: + epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1]) + opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch)) + Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True) + ref=generate(netE, opt, logger) + if opt.eval_gen: + # Evaluate generation + evaluate_gen(opt, ref, logger) + + if opt.eval_recon: + # Evaluate generation + evaluate_recon(opt, netE, outf_syn, logger) + + if opt.eval_recon_mvr: + # Evaluate generation + evaluate_recon_mvr(opt, netE, outf_syn, logger) + + +def parse_args(): + + parser = argparse.ArgumentParser() + parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--classes', default=['airplane']) + + parser.add_argument('--batch_size', type=int, default=20, help='input batch size') + parser.add_argument('--workers', type=int, default=16, help='workers') + parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + + parser.add_argument('--generate',default=False) + parser.add_argument('--eval_gen', default=True) + parser.add_argument('--eval_recon', default=False) + parser.add_argument('--eval_recon_mvr', default=False) + + parser.add_argument('--nc', default=3) + parser.add_argument('--npoints', default=2048) + '''model''' + parser.add_argument('--beta_start', default=0.00001) + parser.add_argument('--beta_end', default=0.008) + parser.add_argument('--schedule_type', default='warm0.1') + parser.add_argument('--time_num', default=1000) + + #params + parser.add_argument('--attention', default=True) + parser.add_argument('--dropout', default=0.1) + parser.add_argument('--embed_dim', type=int, default=64) + parser.add_argument('--loss_type', default='mse') + parser.add_argument('--model_mean_type', default='eps') + parser.add_argument('--model_var_type', default='fixedsmall') + + # constrain function + parser.add_argument('--constrain_eps', default=.1) + parser.add_argument('--constrain_steps', type=int, default=1) + + parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do_best/2020-10-13-13-33-53', help="path to netE (to continue training)") + + '''eval''' + + parser.add_argument('--eval_path', + 
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test_plane/2020-10-18-13-49-20/syn/epoch_2499_samples.pth') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + + opt = parser.parse_args() + + if torch.cuda.is_available(): + opt.cuda = True + else: + opt.cuda = False + + return opt +if __name__ == '__main__': + opt = parse_args() + set_seed(opt) + + main(opt) diff --git a/shapenet/train_generation.py b/shapenet/train_generation.py new file mode 100644 index 0000000..9141740 --- /dev/null +++ b/shapenet/train_generation.py @@ -0,0 +1,853 @@ +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +import torch.utils.data + +import argparse +from model.unet import get_model +from torch.distributions import Normal + +from utils.file_utils import * +from utils.visualize import * +from model.pvcnn_generation import PVCNN2Base +import torch.distributed as dist +from datasets.shapenet_data_pc import ShapeNet15kPointClouds + +''' +some utils +''' +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. + """ + axis = np.asarray(axis) + axis = axis / np.sqrt(np.dot(axis, axis)) + a = np.cos(theta / 2.0) + b, c, d = -axis * np.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + +def rotate(vertices, faces): + ''' + vertices: [numpoints, 3] + ''' + M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() + N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() + K = rotation_matrix([0, 0, 1], np.pi).transpose() + + v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] + return v, f + +def norm(v, f): + v = (v - v.min())/(v.max() - v.min()) - 0.5 + + return v, f + +def getGradNorm(net): + pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) + gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) + return pNorm, gradNorm + + +def weights_init(m): + """ + xavier initialization + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1 and m.weight is not None: + torch.nn.init.xavier_normal_(m.weight) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_() + m.bias.data.fill_(0) + +''' +models +''' +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + (mean1 - mean2)**2 * torch.exp(-logvar2)) + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. 
- cdf_min, torch.ones_like(cdf_min)*1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + assert log_probs.shape == x.shape + return log_probs + +class GaussianDiffusion: + def __init__(self,betas, loss_type, model_mean_type, model_var_type): + self.loss_type = loss_type + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + assert isinstance(betas, np.ndarray) + self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy + assert (betas > 0).all() and (betas <= 1).all() + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + + # initialize twice the actual length so we can keep running for eval + # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) + + alphas = 1. - betas + alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + + self.betas = torch.from_numpy(betas).float() + self.alphas_cumprod = alphas_cumprod.float() + self.alphas_cumprod_prev = alphas_cumprod_prev.float() + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + + betas = torch.from_numpy(betas).float() + alphas = torch.from_numpy(alphas).float() + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.posterior_variance = posterior_variance + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) + self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + + @staticmethod + def _extract(a, t, x_shape): + """ + Extract some coefficients at specified timesteps, + then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. + """ + bs, = t.shape + assert x_shape[0] == bs + out = torch.gather(a, 0, t) + assert out.shape == torch.Size([bs]) + return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) + + + + def q_mean_variance(self, x_start, t): + mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data (t == 0 means diffused for 1 step) + """ + if noise is None: + noise = torch.randn(x_start.shape, device=x_start.device) + assert noise.shape == x_start.shape + return ( + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + ) + + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + ) + posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) + posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + + model_output = denoise_fn(data, t) + + + if self.model_var_type in ['fixedsmall', 'fixedlarge']: + # below: only log_variance is used in the KL computations + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood + 'fixedlarge': (self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), + 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + }[self.model_var_type] + model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) + model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) + else: + raise NotImplementedError(self.model_var_type) + + if self.model_mean_type == 'eps': + x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) + + if clip_denoised: + x_recon = torch.clamp(x_recon, -.5, .5) + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) + else: + raise NotImplementedError(self.loss_type) + + + assert model_mean.shape == x_recon.shape == data.shape + assert model_variance.shape == model_log_variance.shape == data.shape + if return_pred_xstart: + return model_mean, model_variance, model_log_variance, x_recon + else: + return model_mean, model_variance, model_log_variance + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - + self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + ) + + ''' samples ''' + + def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): + """ + Sample from the model + """ + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, + return_pred_xstart=True) + noise = 
noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
+         assert noise.shape == data.shape
+         # no noise when t == 0
+         nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))
+
+         sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
+         assert sample.shape == pred_xstart.shape
+         return (sample, pred_xstart) if return_pred_xstart else sample
+
+     def p_sample_loop(self, denoise_fn, shape, device,
+                       noise_fn=torch.randn, clip_denoised=True, keep_running=False):
+         """
+         Generate samples
+         keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
+         """
+
+         assert isinstance(shape, (tuple, list))
+         img_t = noise_fn(size=shape, dtype=torch.float, device=device)
+         for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
+             t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
+             img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
+                                   clip_denoised=clip_denoised, return_pred_xstart=False)
+
+         assert img_t.shape == shape
+         return img_t
+
+     def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
+                                  noise_fn=torch.randn, clip_denoised=True, keep_running=False):
+         """
+         Generate samples, returning intermediate images
+         Useful for visualizing how denoised images evolve over time
+         Args:
+             freq (int): interval (in timesteps) at which intermediate samples
+                 are stored; the final sample is always appended.
+         """
+         assert isinstance(shape, (tuple, list))
+
+         total_steps = self.num_timesteps if not keep_running else len(self.betas)
+
+         img_t = noise_fn(size=shape, dtype=torch.float, device=device)
+         imgs = [img_t]
+         for t in reversed(range(0, total_steps)):
+             t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
+             img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
+                                   clip_denoised=clip_denoised,
+                                   return_pred_xstart=False)
+             if t % freq == 0 or t == total_steps - 1:
+                 imgs.append(img_t)
+
+         assert imgs[-1].shape == shape
+         return imgs
+
+     '''losses'''
+
+     def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
+         true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start, x_t=data_t, t=t)
+         model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
+             denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
+         kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
+         kl = kl.mean(dim=list(range(1, len(data_start.shape)))) / np.log(2.)
+
+         return (kl, pred_xstart) if return_pred_xstart else kl
+
+     def p_losses(self, denoise_fn, data_start, t, noise=None):
+         """
+         Training loss calculation
+         """
+         B, D, N = data_start.shape
+         assert t.shape == torch.Size([B])
+
+         if noise is None:
+             noise = torch.randn(data_start.shape, dtype=data_start.dtype, device=data_start.device)
+         assert noise.shape == data_start.shape and noise.dtype == data_start.dtype
+
+         data_t = self.q_sample(x_start=data_start, t=t, noise=noise)
+
+         if self.loss_type == 'mse':
+             # predict the noise instead of x_start; seems to be weighted naturally like SNR
+             eps_recon = denoise_fn(data_t, t)
+             assert data_t.shape == data_start.shape
+             assert eps_recon.shape == torch.Size([B, D, N])
+             assert eps_recon.shape == data_start.shape
+             losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
+         elif self.loss_type == 'kl':
+             losses = self._vb_terms_bpd(
+                 denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
+                 return_pred_xstart=False)
+         else:
+             raise NotImplementedError(self.loss_type)
+
+         assert losses.shape == torch.Size([B])
+         return losses
+
+     '''debug'''
+
+     def _prior_bpd(self, x_start):
+
+         with torch.no_grad():
+             B, T = x_start.shape[0], self.num_timesteps
+             t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
+             qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
+             kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
+                                  mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
+             assert kl_prior.shape == x_start.shape
+             return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
+
+     def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
+
+         with torch.no_grad():
+             B, T = x_start.shape[0], self.num_timesteps
+
+             vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
+             for t in reversed(range(T)):
+                 t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
+                 # Calculate VLB term at the current timestep
+                 new_vals_b, pred_xstart = self._vb_terms_bpd(
+                     denoise_fn, data_start=x_start, data_t=self.q_sample(x_start=x_start, t=t_b), t=t_b,
+                     clip_denoised=clip_denoised, return_pred_xstart=True)
+                 # MSE for progressive prediction loss
+                 assert pred_xstart.shape == x_start.shape
+                 new_mse_b = ((pred_xstart - x_start)**2).mean(dim=list(range(1, len(x_start.shape))))
+                 assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
+                 # Insert the calculated term into the tensor of all terms
+                 mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
+                 vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
+                 mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
+                 assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
+
+             prior_bpd_b = self._prior_bpd(x_start)
+             total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
+             assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
+                 total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
+             return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
+
+
+ class PVCNN2(PVCNN2Base):
+     sa_blocks = [
+         ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
+         ((64, 3, 16), (256, 0.2, 32, (64, 128))),
+         ((128, 3, 8), (64, 0.4, 32, (128, 256))),
+         (None, (16, 0.8, 32, (256, 256, 512))),
+     ]
+     fp_blocks = [
+         ((256, 256), (256, 3, 8)),
+         ((256, 256), (256, 3, 8)),
+         ((256, 128), (128, 2, 16)),
+         ((128, 128, 64), (64, 2, 32)),
+     ]
+
+     def __init__(self, num_classes, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1,
+                  voxel_resolution_multiplier=1):
+         super().__init__(
+             num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
+             dropout=dropout, extra_feature_channels=extra_feature_channels,
+             width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
+         )
+
+
+ class Model(nn.Module):
+     def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
+         super(Model, self).__init__()
+         self.diffusion = 
GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) + + self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, + dropout=args.dropout, extra_feature_channels=0) + + def prior_kl(self, x0): + return self.diffusion._prior_bpd(x0) + + def all_kl(self, x0, clip_denoised=True): + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + + return { + 'total_bpd_b': total_bpd_b, + 'terms_bpd': vals_bt, + 'prior_bpd_b': prior_bpd_b, + 'mse_bt':mse_bt + } + + + def _denoise(self, data, t): + B, D,N= data.shape + assert data.dtype == torch.float + assert t.shape == torch.Size([B]) and t.dtype == torch.int64 + + out = self.model(data, t) + + assert out.shape == torch.Size([B, D, N]) + return out + + def get_loss_iter(self, data, noises=None): + B, D, N = data.shape + t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) + + if noises is not None: + noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + + losses = self.diffusion.p_losses( + denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + assert losses.shape == t.shape == torch.Size([B]) + return losses + + def gen_samples(self, shape, device, noise_fn=torch.randn, + clip_denoised=True, + keep_running=False): + return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running) + + def gen_sample_traj(self, shape, device, freq, noise_fn=torch.randn, + clip_denoised=True,keep_running=False): + return self.diffusion.p_sample_loop_trajectory(self._denoise, shape=shape, device=device, noise_fn=noise_fn, freq=freq, + clip_denoised=clip_denoised, + keep_running=keep_running) + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def multi_gpu_wrapper(self, f): + self.model = f(self.model) + + +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == 'linear': + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == 'warm0.1': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.1) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.2': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.2) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + elif schedule_type == 'warm0.5': + + betas = b_end * np.ones(time_num, dtype=np.float64) + warmup_time = int(time_num * 0.5) + betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) + else: + raise NotImplementedError(schedule_type) + return betas + + +def get_dataset(dataroot, npoints,category): + tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=[category], split='train', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=True) + te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, + categories=[category], split='val', + tr_sample_size=npoints, + te_sample_size=npoints, + scale=1., + normalize_per_shape=False, + normalize_std_per_axis=False, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return tr_dataset, te_dataset + + +def get_dataloader(opt, train_dataset, test_dataset=None): + + if opt.distribution_type == 'multi': + train_sampler = 
torch.utils.data.distributed.DistributedSampler(
+            train_dataset,
+            num_replicas=opt.world_size,
+            rank=opt.rank
+        )
+        if test_dataset is not None:
+            test_sampler = torch.utils.data.distributed.DistributedSampler(
+                test_dataset,
+                num_replicas=opt.world_size,
+                rank=opt.rank
+            )
+        else:
+            test_sampler = None
+    else:
+        train_sampler = None
+        test_sampler = None
+
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs, sampler=train_sampler,
+                                                   shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
+
+    if test_dataset is not None:
+        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
+                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
+    else:
+        test_dataloader = None
+
+    return train_dataloader, test_dataloader, train_sampler, test_sampler
+
+
+def train(gpu, opt, output_dir, noises_init):
+
+    set_seed(opt)
+    logger = setup_logging(output_dir)
+    if opt.distribution_type == 'multi':
+        should_diag = gpu == 0
+    else:
+        should_diag = True
+    if should_diag:
+        outf_syn, = setup_output_subdirs(output_dir, 'syn')
+
+    if opt.distribution_type == 'multi':
+        if opt.dist_url == "env://" and opt.rank == -1:
+            opt.rank = int(os.environ["RANK"])
+
+        base_rank = opt.rank * opt.ngpus_per_node
+        opt.rank = base_rank + gpu
+        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
+                                world_size=opt.world_size, rank=opt.rank)
+
+        opt.bs = int(opt.bs / opt.ngpus_per_node)
+        opt.workers = 0  # no extra dataloader workers inside each spawned process
+
+        opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
+        opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
+        opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
+
+
+    ''' data '''
+    train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
+    dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
+
+
+    '''
+    create networks
+    '''
+
+    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
+    model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
+
+    if opt.distribution_type == 'multi':  # Multiple processes, single GPU per process
+        def _transform_(m):
+            return nn.parallel.DistributedDataParallel(
+                m, device_ids=[gpu], output_device=gpu)
+
+        torch.cuda.set_device(gpu)
+        model.cuda(gpu)
+        model.multi_gpu_wrapper(_transform_)
+
+    elif opt.distribution_type == 'single':
+        def _transform_(m):
+            return nn.parallel.DataParallel(m)
+        model = model.cuda()
+        model.multi_gpu_wrapper(_transform_)
+
+    elif gpu is not None:
+        torch.cuda.set_device(gpu)
+        model = model.cuda(gpu)
+    else:
+        raise ValueError('distribution_type = multi | single | None')
+
+    if should_diag:
+        logger.info(opt)
+
+    optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999))
+
+    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.lr_gamma)
+
+    if opt.model != '':
+        ckpt = torch.load(opt.model)
+        model.load_state_dict(ckpt['model_state'])
+        optimizer.load_state_dict(ckpt['optimizer_state'])
+        start_epoch = ckpt['epoch'] + 1
+    else:
+        start_epoch = 0
+
+    def new_x_chain(x, num_chain):
+        return torch.randn(num_chain, *x.shape[1:], device=x.device)
+
+
+    for epoch in range(start_epoch, opt.niter):
+
+        if opt.distribution_type == 'multi':
+            train_sampler.set_epoch(epoch)
+
+        lr_scheduler.step(epoch)
+
+        for i, data in enumerate(dataloader):
+            x = data['train_points'].transpose(1, 2)
+            noises_batch = noises_init[data['idx']].transpose(1, 2)
+
+            '''
+            train
diffusion
+            '''
+
+            if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
+                x = x.cuda(gpu)
+                noises_batch = noises_batch.cuda(gpu)
+            elif opt.distribution_type == 'single':
+                x = x.cuda()
+                noises_batch = noises_batch.cuda()
+
+            loss = model.get_loss_iter(x, noises_batch).mean()
+
+            optimizer.zero_grad()
+            loss.backward()
+            netpNorm, netgradNorm = getGradNorm(model)
+            if opt.grad_clip is not None:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
+
+            optimizer.step()
+
+            if i % opt.print_freq == 0 and should_diag:
+                logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
+                            'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
+                            .format(
+                                epoch, opt.niter, i, len(dataloader), loss.item(),
+                                netpNorm, netgradNorm,
+                            ))
+
+        if (epoch + 1) % opt.diagIter == 0 and should_diag:
+
+            logger.info('Diagnosis:')
+
+            x_range = [x.min().item(), x.max().item()]
+            kl_stats = model.all_kl(x)
+            logger.info('      [{:>3d}/{:>3d}] '
+                        'x_range: [{:>10.4f}, {:>10.4f}], '
+                        'total_bpd_b: {:>10.4f}, '
+                        'terms_bpd: {:>10.4f}, '
+                        'prior_bpd_b: {:>10.4f} '
+                        'mse_bt: {:>10.4f} '
+                        .format(
+                            epoch, opt.niter,
+                            *x_range,
+                            kl_stats['total_bpd_b'].item(),
+                            kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
+                        ))
+
+        if (epoch + 1) % opt.vizIter == 0 and should_diag:
+            logger.info('Generation: eval')
+
+            model.eval()
+            with torch.no_grad():
+
+                x_gen_eval = model.gen_samples(new_x_chain(x, 25).shape, x.device, clip_denoised=False)
+                x_gen_list = model.gen_sample_traj(new_x_chain(x, 1).shape, x.device, freq=40, clip_denoised=False)
+                x_gen_all = torch.cat(x_gen_list, dim=0)
+
+                gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
+                gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
+
+                logger.info('      [{:>3d}/{:>3d}] '
+                            'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
+                            'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
+                            .format(
+                                epoch, opt.niter,
+                                *gen_eval_range, *gen_stats,
+                            ))
+
+            visualize_pointcloud_batch('%s/epoch_%03d_samples_eval.png' % (outf_syn, epoch),
+                                       x_gen_eval.transpose(1, 2), None, None,
+                                       None)
+
+            visualize_pointcloud_batch('%s/epoch_%03d_samples_eval_all.png' % (outf_syn, epoch),
+                                       x_gen_all.transpose(1, 2), None,
+                                       None,
+                                       None)
+
+            visualize_pointcloud_batch('%s/epoch_%03d_x.png' % (outf_syn, epoch), x.transpose(1, 2), None,
+                                       None,
+                                       None)
+
+            logger.info('Generation: train')
+            model.train()
+
+        if (epoch + 1) % opt.saveIter == 0:
+
+            if should_diag:
+                save_dict = {
+                    'epoch': epoch,
+                    'model_state': model.state_dict(),
+                    'optimizer_state': optimizer.state_dict()
+                }
+
+                torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
+
+            if opt.distribution_type == 'multi':
+                dist.barrier()
+                map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
+                model.load_state_dict(
+                    torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
+
+    if opt.distribution_type == 'multi':
+        dist.destroy_process_group()
+
+def main():
+    opt = parse_args()
+    if opt.category == 'airplane':
+        opt.beta_start = 1e-5
+        opt.beta_end = 0.008
+        opt.schedule_type = 'warm0.1'
+
+    exp_id = os.path.splitext(os.path.basename(__file__))[0]
+    dir_id = os.path.dirname(__file__)
+    output_dir = get_output_dir(dir_id, exp_id)
+    copy_source(__file__, output_dir)
+
+    ''' workaround '''
+    # pre-sample one fixed noise tensor per training shape; train() indexes into
+    # it with data['idx'] so the same noise is reused for a shape across epochs
+    train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
+    noises_init = torch.randn(len(train_dataset), opt.npoints, opt.nc)
+
+    if opt.dist_url == "env://" and opt.world_size == -1:
+        opt.world_size = 
int(os.environ["WORLD_SIZE"])
+
+    if opt.distribution_type == 'multi':
+        opt.ngpus_per_node = torch.cuda.device_count()
+        opt.world_size = opt.ngpus_per_node * opt.world_size
+        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
+    else:
+        train(opt.gpu, opt, output_dir, noises_init)
+
+
+def parse_args():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNetCore.v2.PC15k dataset')
+    parser.add_argument('--category', default='chair')
+
+    parser.add_argument('--bs', type=int, default=48, help='input batch size')
+    parser.add_argument('--workers', type=int, default=16, help='number of dataloader workers')
+    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
+
+    parser.add_argument('--nc', type=int, default=3, help='number of point channels (xyz)')
+    parser.add_argument('--npoints', type=int, default=2048, help='number of points per shape')
+    '''model'''
+    parser.add_argument('--beta_start', type=float, default=0.0001)
+    parser.add_argument('--beta_end', type=float, default=0.02)
+    parser.add_argument('--schedule_type', default='linear')
+    parser.add_argument('--time_num', type=int, default=1000)
+
+    #params
+    parser.add_argument('--attention', default=True)
+    parser.add_argument('--dropout', type=float, default=0.1)
+    parser.add_argument('--embed_dim', type=int, default=64)
+    parser.add_argument('--loss_type', default='mse')
+    parser.add_argument('--model_mean_type', default='eps')
+    parser.add_argument('--model_var_type', default='fixedsmall')
+
+    parser.add_argument('--lr', type=float, default=2e-4, help='learning rate, default=0.0002')
+    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam, default=0.5')
+    parser.add_argument('--decay', type=float, default=0, help='weight decay')
+    parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm (no clipping if None)')
+    parser.add_argument('--lr_gamma', type=float, default=0.998, help='lr decay rate')
+
+    parser.add_argument('--model', default='', help="path to model (to continue training)")
+
+
+    '''distributed'''
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='Number of distributed nodes.')
+    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
+                        help='url used to set up distributed training')
+    parser.add_argument('--dist_backend', default='nccl', type=str,
+                        help='distributed backend')
+    parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
+                        help='Use multi-processing distributed training to launch '
+                             'N processes per node, which has N GPUs. This is the '
+                             'fastest way to use PyTorch for either single node or '
+                             'multi node data parallel training')
+    parser.add_argument('--rank', default=0, type=int,
+                        help='node rank for distributed training')
+    parser.add_argument('--gpu', default=None, type=int,
+                        help='GPU id to use. 
None means using all available GPUs.') + + '''eval''' + parser.add_argument('--saveIter', default=100, help='unit: epoch') + parser.add_argument('--diagIter', default=50, help='unit: epoch') + parser.add_argument('--vizIter', default=50, help='unit: epoch') + parser.add_argument('--print_freq', default=50, help='unit: iter') + + parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + + + opt = parser.parse_args() + + return opt + +if __name__ == '__main__': + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/binvox_rw.py b/utils/binvox_rw.py new file mode 100644 index 0000000..73190d2 --- /dev/null +++ b/utils/binvox_rw.py @@ -0,0 +1,266 @@ +# Copyright (C) 2012 Daniel Maturana +# This file is part of binvox-rw-py. +# +# binvox-rw-py is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# binvox-rw-py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with binvox-rw-py. If not, see . +# + +""" +Binvox to Numpy and back. +>>> import numpy as np +>>> import binvox_rw +>>> with open('chair.binvox', 'rb') as f: +... m1 = binvox_rw.read_as_3d_array(f) +... +>>> m1.dims +[32, 32, 32] +>>> m1.scale +41.133000000000003 +>>> m1.translate +[0.0, 0.0, 0.0] +>>> with open('chair_out.binvox', 'wb') as f: +... m1.write(f) +... +>>> with open('chair_out.binvox', 'rb') as f: +... m2 = binvox_rw.read_as_3d_array(f) +... +>>> m1.dims==m2.dims +True +>>> m1.scale==m2.scale +True +>>> m1.translate==m2.translate +True +>>> np.all(m1.data==m2.data) +True +>>> with open('chair.binvox', 'rb') as f: +... md = binvox_rw.read_as_3d_array(f) +... +>>> with open('chair.binvox', 'rb') as f: +... ms = binvox_rw.read_as_coord_array(f) +... +>>> data_ds = binvox_rw.dense_to_sparse(md.data) +>>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32) +>>> np.all(data_sd==md.data) +True +>>> # the ordering of elements returned by numpy.nonzero changes with axis +>>> # ordering, so to compare for equality we first lexically sort the voxels. +>>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)]) +True +""" + +import numpy as np + +class Voxels(object): + """ Holds a binvox model. + data is either a three-dimensional numpy boolean array (dense representation) + or a two-dimensional numpy float array (coordinate representation). + dims, translate and scale are the model metadata. + dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. + scale and translate relate the voxels to the original model coordinates. 
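+    (The model thus occupies an axis-aligned cube of side length scale whose minimum corner sits at translate.)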
+ To translate voxel coordinates i, j, k to original coordinates x, y, z: + x_n = (i+.5)/dims[0] + y_n = (j+.5)/dims[1] + z_n = (k+.5)/dims[2] + x = scale*x_n + translate[0] + y = scale*y_n + translate[1] + z = scale*z_n + translate[2] + """ + + def __init__(self, data, dims, translate, scale, axis_order): + self.data = data + self.dims = dims + self.translate = translate + self.scale = scale + assert (axis_order in ('xzy', 'xyz')) + self.axis_order = axis_order + + def clone(self): + data = self.data.copy() + dims = self.dims[:] + translate = self.translate[:] + return Voxels(data, dims, translate, self.scale, self.axis_order) + + def write(self, fp): + write(self, fp) + +def read_header(fp): + """ Read binvox header. Mostly meant for internal use. + """ + line = fp.readline().strip() + if not line.startswith(b'#binvox'): + raise IOError('Not a binvox file') + dims = list(map(int, fp.readline().strip().split(b' ')[1:])) + translate = list(map(float, fp.readline().strip().split(b' ')[1:])) + scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] + line = fp.readline() + return dims, translate, scale + +def read_as_3d_array(fp, fix_coords=True): + """ Read binary binvox format as array. + Returns the model with accompanying metadata. + Voxels are stored in a three-dimensional numpy array, which is simple and + direct, but may use a lot of memory for large models. (Storage requirements + are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy + boolean arrays use a byte per element). + Doesn't do any checks on input except for the '#binvox' line. + """ + dims, translate, scale = read_header(fp) + raw_data = np.frombuffer(fp.read(), dtype=np.uint8) + # if just using reshape() on the raw data: + # indexing the array as array[i,j,k], the indices map into the + # coords as: + # i -> x + # j -> z + # k -> y + # if fix_coords is true, then data is rearranged so that + # mapping is + # i -> x + # j -> y + # k -> z + values, counts = raw_data[::2], raw_data[1::2] + data = np.repeat(values, counts).astype(np.bool) + data = data.reshape(dims) + if fix_coords: + # xzy to xyz TODO the right thing + data = np.transpose(data, (0, 2, 1)) + axis_order = 'xyz' + else: + axis_order = 'xzy' + return Voxels(data, dims, translate, scale, axis_order) + +def read_as_coord_array(fp, fix_coords=True): + """ Read binary binvox format as coordinates. + Returns binvox model with voxels in a "coordinate" representation, i.e. an + 3 x N array where N is the number of nonzero voxels. Each column + corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates + of the voxel. (The odd ordering is due to the way binvox format lays out + data). Note that coordinates refer to the binvox voxels, without any + scaling or translation. + Use this to save memory if your model is very sparse (mostly empty). + Doesn't do any checks on input except for the '#binvox' line. + """ + dims, translate, scale = read_header(fp) + raw_data = np.frombuffer(fp.read(), dtype=np.uint8) + + values, counts = raw_data[::2], raw_data[1::2] + + sz = np.prod(dims) + index, end_index = 0, 0 + end_indices = np.cumsum(counts) + indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) + + values = values.astype(np.bool) + indices = indices[values] + end_indices = end_indices[values] + + nz_voxels = [] + for index, end_index in zip(indices, end_indices): + nz_voxels.extend(range(index, end_index)) + nz_voxels = np.array(nz_voxels) + # TODO are these dims correct? 
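+    # (nz_voxels holds the flat indices of the occupied cells; the lines below
+    # invert binvox's linear index back into 3-D coordinates)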
+    # according to docs,
+    # index = x * wxh + z * width + y;  // wxh = width * height = d * d
+
+    x = nz_voxels // (dims[0]*dims[1])
+    zwpy = nz_voxels % (dims[0]*dims[1])  # z*w + y
+    z = zwpy // dims[0]
+    y = zwpy % dims[0]
+    if fix_coords:
+        data = np.vstack((x, y, z))
+        axis_order = 'xyz'
+    else:
+        data = np.vstack((x, z, y))
+        axis_order = 'xzy'
+
+    #return Voxels(data, dims, translate, scale, axis_order)
+    return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
+
+def dense_to_sparse(voxel_data, dtype=np.int):
+    """ From dense representation to sparse (coordinate) representation.
+    No coordinate reordering.
+    """
+    if voxel_data.ndim != 3:
+        raise ValueError('voxel_data is wrong shape; should be 3D array.')
+    return np.asarray(np.nonzero(voxel_data), dtype)
+
+def sparse_to_dense(voxel_data, dims, dtype=np.bool):
+    if voxel_data.ndim != 2 or voxel_data.shape[0] != 3:
+        raise ValueError('voxel_data is wrong shape; should be 3xN array.')
+    if np.isscalar(dims):
+        dims = [dims] * 3
+    dims = np.atleast_2d(dims).T
+    # truncate to integers
+    xyz = voxel_data.astype(np.int)
+    # discard voxels that fall outside dims
+    valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
+    xyz = xyz[:, valid_ix]
+    out = np.zeros(dims.flatten(), dtype=dtype)
+    out[tuple(xyz)] = True
+    return out
+
+#def get_linear_index(x, y, z, dims):
+    #""" Assuming xzy order. (y increasing fastest.
+    #TODO ensure this is right when dims are not all same
+    #"""
+    #return x*(dims[1]*dims[2]) + z*dims[1] + y
+
+def write(voxel_model, fp):
+    """ Write binary binvox format.
+    Note that when saving a model in sparse (coordinate) format, it is first
+    converted to dense format.
+    Doesn't check if the model is 'sane'.
+    The file object must be opened in binary mode ('wb').
+    """
+    if voxel_model.data.ndim == 2:
+        # TODO avoid conversion to dense
+        dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
+    else:
+        dense_voxel_data = voxel_model.data
+
+    fp.write(b'#binvox 1\n')
+    fp.write(('dim ' + ' '.join(map(str, voxel_model.dims)) + '\n').encode('ascii'))
+    fp.write(('translate ' + ' '.join(map(str, voxel_model.translate)) + '\n').encode('ascii'))
+    fp.write(('scale ' + str(voxel_model.scale) + '\n').encode('ascii'))
+    fp.write(b'data\n')
+    if not voxel_model.axis_order in ('xzy', 'xyz'):
+        raise ValueError('Unsupported voxel model axis order')
+
+    if voxel_model.axis_order == 'xzy':
+        voxels_flat = dense_voxel_data.flatten()
+    elif voxel_model.axis_order == 'xyz':
+        voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()
+
+    # keep a sort of state machine for writing run length encoding
+    state = voxels_flat[0]
+    ctr = 0
+    for c in voxels_flat:
+        if c == state:
+            ctr += 1
+            # if ctr hits max, dump
+            if ctr == 255:
+                fp.write(bytes((int(state), ctr)))
+                ctr = 0
+        else:
+            # if switch state, dump
+            fp.write(bytes((int(state), ctr)))
+            state = c
+            ctr = 1
+    # flush out remainders
+    if ctr > 0:
+        fp.write(bytes((int(state), ctr)))
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
\ No newline at end of file
diff --git a/utils/conversion.py b/utils/conversion.py
new file mode 100644
index 0000000..1d90335
--- /dev/null
+++ b/utils/conversion.py
@@ -0,0 +1,46 @@
+
+from skimage import measure
+import numpy as np
+
+
+def get_mesh(tsdf_vol, color_vol, threshold=0, vol_max=.5, vol_min=-.5):
+    """Compute a mesh from the voxel volume using marching cubes.
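+    Assumes a cubic volume spanning [vol_min, vol_max] along every axis.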
+ """ + vol_origin = vol_min + voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1] + + # Marching cubes + verts, faces, norms, vals = measure.marching_cubes_lewiner(tsdf_vol, level=threshold) + verts_ind = np.round(verts).astype(int) + verts = verts * voxel_size + vol_origin # voxel grid coordinates to world coordinates + + # Get vertex colors + if color_vol is None: + return verts, faces, norms + colors = color_vol[:, verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].T + + return verts, faces, norms, colors + + +def get_point_cloud(tsdf_vol, color_vol, vol_max=0.5, vol_min=-0.5): + vol_origin = vol_min + voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1] + # Marching cubes + verts = measure.marching_cubes_lewiner(tsdf_vol, level=0)[0] + verts_ind = np.round(verts).astype(int) + verts = verts * voxel_size + vol_origin + + # Get vertex colors + colors = color_vol[:, verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].T + + pc = np.hstack([verts, colors]) + + return pc + +def sparse_to_dense_voxel(coords, feats, res): + coords = coords.astype('int64', copy=False) + a = np.zeros((res, res, res), dtype=feats.dtype) + + a[coords[:,0],coords[:,1],coords[:,2] ] = feats[:,0].astype(a.dtype, copy=False) + + return a \ No newline at end of file diff --git a/utils/file_utils.py b/utils/file_utils.py new file mode 100644 index 0000000..e6dbe67 --- /dev/null +++ b/utils/file_utils.py @@ -0,0 +1,85 @@ +import os +import random +import sys + +from shutil import copyfile +import datetime + +import torch + +import logging +logger = logging.getLogger() + +import numpy as np + +def set_global_gpu_env(opt): + + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu) + + + torch.cuda.set_device(opt.gpu) + +def copy_source(file, output_dir): + copyfile(file, os.path.join(output_dir, os.path.basename(file))) + + + +def setup_logging(output_dir): + log_format = logging.Formatter("%(asctime)s : %(message)s") + logger = logging.getLogger() + logger.handlers = [] + output_file = os.path.join(output_dir, 'output.log') + file_handler = logging.FileHandler(output_file) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(log_format) + err_handler = logging.StreamHandler(sys.stderr) + err_handler.setFormatter(log_format) + logger.addHandler(err_handler) + logger.setLevel(logging.INFO) + + return logger + + +def get_output_dir(prefix, exp_id): + t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + output_dir = os.path.join(prefix, 'output/' + exp_id, t) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + return output_dir + + + +def set_seed(opt): + + if opt.manualSeed is None: + opt.manualSeed = random.randint(1, 10000) + print("Random Seed: ", opt.manualSeed) + random.seed(opt.manualSeed) + torch.manual_seed(opt.manualSeed) + np.random.seed(opt.manualSeed) + if opt.gpu is not None and torch.cuda.is_available(): + torch.cuda.manual_seed_all(opt.manualSeed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + +def setup_output_subdirs(output_dir, *subfolders): + + output_subdirs = output_dir + try: + os.makedirs(output_subdirs) + except OSError: + pass + + subfolder_list = [] + for sf in subfolders: + curr_subf = os.path.join(output_subdirs, sf) + try: + os.makedirs(curr_subf) + except OSError: + pass + subfolder_list.append(curr_subf) + + return subfolder_list \ No newline at end of file diff --git a/utils/metrics.py 
b/utils/metrics.py
new file mode 100644
index 0000000..ec25e26
--- /dev/null
+++ b/utils/metrics.py
@@ -0,0 +1,232 @@
+import numpy as np
+import warnings
+
+from scipy.stats import entropy
+from sklearn.neighbors import NearestNeighbors
+
+def iterate_in_chunks(l, n):
+    '''Yield successive 'n'-sized chunks from iterable 'l'.
+    Note: last chunk will be smaller than l if n doesn't divide l perfectly.
+    '''
+    for i in range(0, len(l), n):
+        yield l[i:i + n]
+
+def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
+    '''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
+    that is placed in the unit-cube.
+    If clip_sphere is True, it drops the "corner" cells that lie outside the unit-sphere.
+    '''
+    grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
+    spacing = 1.0 / float(resolution - 1)
+    for i in range(resolution):
+        for j in range(resolution):
+            for k in range(resolution):
+                grid[i, j, k, 0] = i * spacing - 0.5
+                grid[i, j, k, 1] = j * spacing - 0.5
+                grid[i, j, k, 2] = k * spacing - 0.5
+
+    if clip_sphere:
+        grid = grid.reshape(-1, 3)
+        grid = grid[np.linalg.norm(grid, axis=1) <= 0.5]
+
+    return grid, spacing
+
+def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False,
+                             use_EMD=False):
+    '''Computes the MMD between two sets of point-clouds.
+    Args:
+        sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched and
+            compared to a set of "reference" point-clouds.
+        ref_pcs (numpy array RxKx3): the R point-clouds, each of K points that constitute the set of
+            "reference" point-clouds.
+        batch_size (int): size of the batches used when comparing the sample-vs-ref point-clouds.
+        normalize (boolean): if True, the distances are normalized by dividing them by
+            the number of points of the point-clouds (n_pc_points).
+        use_sqrt (boolean): When the matching is based on Chamfer (default behavior), if True, the
+            Chamfer is computed on the (not-squared) euclidean distances of the matched point-wise
+            distances.
+        sess (tf.Session, default None): if None, it will make a new Session for this.
+        use_EMD (boolean): If true, the matchings are based on the EMD.
+    Returns:
+        A tuple containing the MMD and all the matched distances of which the MMD is their mean.
+    '''
+
+    n_ref, n_pc_points, pc_dim = ref_pcs.shape
+    _, n_pc_points_s, pc_dim_s = sample_pcs.shape
+
+    if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
+        raise ValueError('Incompatible size of point-clouds.')
+
+    # NOTE: minimum_mathing_distance_tf_graph is not defined in this module; it is
+    # expected to be provided by the TensorFlow evaluation code this file was adapted from.
+    ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(n_pc_points, normalize=normalize,
+                                                                                  sess=sess, use_sqrt=use_sqrt,
+                                                                                  use_EMD=use_EMD)
+    matched_dists = []
+    for i in range(n_ref):
+        best_in_all_batches = []
+        if verbose and i % 50 == 0:
+            print(i)
+
+        for sample_chunk in iterate_in_chunks(sample_pcs, batch_size):
+            feed_dict = {ref_pl: np.expand_dims(ref_pcs[i], 0), sample_pl: sample_chunk}
+            b = sess.run(best_in_batch, feed_dict=feed_dict)
+            best_in_all_batches.append(b)
+        matched_dists.append(np.min(best_in_all_batches))
+    mmd = np.mean(matched_dists)
+    return mmd, matched_dists
+
+
+def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False,
+             ret_dist=False):
+    '''Computes the Coverage between two sets of point-clouds.
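+    Coverage is the fraction of reference clouds that end up as the nearest
+    neighbour of at least one sample cloud.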
+    Args:
+        sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched
+            and compared to a set of "reference" point-clouds.
+        ref_pcs (numpy array RxKx3): the R point-clouds, each of K points that constitute the
+            set of "reference" point-clouds.
+        batch_size (int): size of the batches used when comparing the sample-vs-ref point-clouds.
+        normalize (boolean): if True, the distances are normalized by dividing them by
+            the number of points of the point-clouds (n_pc_points).
+        use_sqrt (boolean): When the matching is based on Chamfer (default behavior), if True,
+            the Chamfer is computed on the (not-squared) euclidean distances of the matched
+            point-wise distances.
+        sess (tf.Session): If None, it will make a new Session for this.
+        use_EMD (boolean): If true, the matchings are based on the EMD.
+        ret_dist (boolean): If true, it will also return the distance between each sample_pc and
+            its matched ground-truth.
+    Returns: the coverage score (float),
+             the indices of the ref_pcs that are matched with each sample_pc
+             and optionally the matched distances of the sample_pcs.
+    '''
+    n_ref, n_pc_points, pc_dim = ref_pcs.shape
+    n_sam, n_pc_points_s, pc_dim_s = sample_pcs.shape
+
+    if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
+        raise ValueError('Incompatible Point-Clouds.')
+
+    ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(n_pc_points,
+                                                                                            normalize=normalize,
+                                                                                            sess=sess,
+                                                                                            use_sqrt=use_sqrt,
+                                                                                            use_EMD=use_EMD)
+    matched_gt = []
+    matched_dist = []
+    for i in range(n_sam):
+        best_in_all_batches = []
+        loc_in_all_batches = []
+
+        if verbose and i % 50 == 0:
+            print(i)
+
+        for ref_chunk in iterate_in_chunks(ref_pcs, batch_size):
+            feed_dict = {ref_pl: np.expand_dims(sample_pcs[i], 0), sample_pl: ref_chunk}
+            b, loc = sess.run([best_in_batch, loc_of_best], feed_dict=feed_dict)
+            best_in_all_batches.append(b)
+            loc_in_all_batches.append(loc)
+
+        best_in_all_batches = np.array(best_in_all_batches)
+        b_hit = np.argmin(best_in_all_batches)  # In which batch the minimum occurred.
+        matched_dist.append(np.min(best_in_all_batches))
+        hit = np.array(loc_in_all_batches)[b_hit]
+        matched_gt.append(batch_size * b_hit + hit)
+
+    cov = len(np.unique(matched_gt)) / float(n_ref)
+
+    if ret_dist:
+        return cov, matched_gt, matched_dist
+    else:
+        return cov, matched_gt
+
+
+def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
+    '''Computes the JSD between two sets of point-clouds, as introduced in the paper
+    "Learning Representations And Generative Models For 3D Point Clouds".
+    Args:
+        sample_pcs: (np.ndarray S1xKx3) S1 point-clouds, each of K points.
+        ref_pcs: (np.ndarray S2xKx3) S2 point-clouds, each of K points.
+        resolution: (int) grid-resolution. Affects granularity of measurements.
+    '''
+    in_unit_sphere = True
+    sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
+    ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
+    return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
+
+
+def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
+    '''Given a collection of point-clouds, estimate the entropy of the random variables
+    corresponding to occupancy-grid activation patterns.
+    Inputs:
+        pclouds: (numpy array) #point-clouds x points per point-cloud x 3
+        grid_resolution (int): size of the occupancy grid that will be used.
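+    Returns: (average per-cell occupancy entropy, per-cell point counts).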
+ ''' + epsilon = 10e-4 + bound = 0.5 + epsilon + if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: + warnings.warn('Point-clouds are not in unit cube.') + + if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: + warnings.warn('Point-clouds are not in unit sphere.') + + grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere) + grid_coordinates = grid_coordinates.reshape(-1, 3) + grid_counters = np.zeros(len(grid_coordinates)) + grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) + nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates) + + for pc in pclouds: + _, indices = nn.kneighbors(pc) + indices = np.squeeze(indices) + for i in indices: + grid_counters[i] += 1 + indices = np.unique(indices) + for i in indices: + grid_bernoulli_rvars[i] += 1 + + acc_entropy = 0.0 + n = float(len(pclouds)) + for g in grid_bernoulli_rvars: + p = 0.0 + if g > 0: + p = float(g) / n + acc_entropy += entropy([p, 1.0 - p]) + + return acc_entropy / len(grid_counters), grid_counters + +def jensen_shannon_divergence(P, Q): + if np.any(P < 0) or np.any(Q < 0): + raise ValueError('Negative values.') + if len(P) != len(Q): + raise ValueError('Non equal size.') + + P_ = P / np.sum(P) # Ensure probabilities. + Q_ = Q / np.sum(Q) + + e1 = entropy(P_, base=2) + e2 = entropy(Q_, base=2) + e_sum = entropy((P_ + Q_) / 2.0, base=2) + res = e_sum - ((e1 + e2) / 2.0) + + res2 = _jsdiv(P_, Q_) + + if not np.allclose(res, res2, atol=10e-5, rtol=0): + warnings.warn('Numerical values of two JSD methods don\'t agree.') + + return res + + +def _jsdiv(P, Q): + '''another way of computing JSD''' + def _kldiv(A, B): + a = A.copy() + b = B.copy() + idx = np.logical_and(a > 0, b > 0) + a = a[idx] + b = b[idx] + return np.sum([v for v in a * np.log2(a / b)]) + + P_ = P / np.sum(P) + Q_ = Q / np.sum(Q) + + M = 0.5 * (P_ + Q_) + + return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M)) \ No newline at end of file diff --git a/utils/mitsuba_renderer.py b/utils/mitsuba_renderer.py new file mode 100644 index 0000000..26cc3bf --- /dev/null +++ b/utils/mitsuba_renderer.py @@ -0,0 +1,146 @@ +import numpy as np +from pathlib import Path +import os + + +def standardize_bbox(pcl, points_per_object, scale=None): + pt_indices = np.random.choice(pcl.shape[0], points_per_object, replace=False) + np.random.shuffle(pt_indices) + pcl = pcl[pt_indices] # n by 3 + mins = np.amin(pcl, axis=0) + maxs = np.amax(pcl, axis=0) + center = (mins + maxs) / 2. 
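+    # Recenter on the bbox midpoint and divide by the longest bbox edge so the
+    # cloud fits in [-0.5, 0.5]^3; passing a fixed `scale` keeps a batch of
+    # renders at a consistent size.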
+ if scale is None: + scale = np.amax(maxs - mins) + result = ((pcl - center) / scale).astype(np.float32) # [-0.5, 0.5] + return result + + +xml_head = \ + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + +xml_ball_segment = \ + """ + + + + + + + + + + """ + +xml_tail = \ + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + + +def colormap_fn(x, y, z): + vec = np.array([x, y, z]) + vec = np.clip(vec, 0.001, 1.0) + norm = np.sqrt(np.sum(vec ** 2)) + vec /= norm + return [vec[0], vec[1], vec[2]] + + +color_dict = {'r': [163, 102, 96], 'g': [20, 130, 3], + 'o': [145, 128, 47], 'b': [91, 102, 112], 'p':[133,111,139], 'br':[111,92,81]} + +color_map = {'airplane': 'r', 'chair': 'o', 'car': 'b', 'table': 'p', 'lamp':'br'} +fov_map = {'airplane': 12, 'chair': 16, 'car':15, 'table': 13, 'lamp':13} +radius_map = {'airplane': 0.02, 'chair': 0.035, 'car': 0.01, 'table':0.035, 'lamp':0.035} + +def write_to_xml_batch(dir, pcl_batch, filenames=None, color_batch=None, cat='airplane'): + default_color = color_map[cat] + Path(dir).mkdir(parents=True, exist_ok=True) + if filenames is not None: + assert len(filenames) == pcl_batch.shape[0] + # mins = np.amin(pcl_batch, axis=(0,1)) + # maxs = np.amax(pcl_batch, axis=(0,1)) + # scale = 1; print(np.amax(maxs - mins)) + + for k, pcl in enumerate(pcl_batch): + xml_segments = [xml_head.format(fov_map[cat])] + pcl = standardize_bbox(pcl, pcl.shape[0]) + pcl = pcl[:, [2, 0, 1]] + pcl[:, 0] *= -1 + pcl[:, 2] += 0.0125 + for i in range(pcl.shape[0]): + if color_batch is not None: + color = color_batch[k, i] + else: + color = np.array(color_dict[default_color]) / 255 + # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125) + xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color)) + xml_segments.append( + xml_tail.format(pcl[:, 2].min())) + + xml_content = str.join('', xml_segments) + + if filenames is None: + fn = 'sample_{}.xml'.format(k) + else: + fn = filenames[k] + with open(os.path.join(dir, fn), 'w') as f: + f.write(xml_content) diff --git a/utils/mitsuba_renderer2.py b/utils/mitsuba_renderer2.py new file mode 100644 index 0000000..cb725fd --- /dev/null +++ b/utils/mitsuba_renderer2.py @@ -0,0 +1,170 @@ +import numpy as np +from pathlib import Path +import os + + +def standardize_bbox(pcl, points_per_object): + pt_indices = np.random.choice(pcl.shape[0], points_per_object, replace=False) + np.random.shuffle(pt_indices) + pcl = pcl[pt_indices] # n by 3 + mins = np.amin(pcl, axis=0) + maxs = np.amax(pcl, axis=0) + center = (mins + maxs) / 2. 
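+    # Same normalization as mitsuba_renderer.standardize_bbox, but the scale is
+    # always recomputed per cloud.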
+ scale = np.amax(maxs - mins) + result = ((pcl - center) / scale).astype(np.float32) # [-0.5, 0.5] + return result + + +xml_head = \ + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + +xml_ball_segment = \ + """ + + + + + + + + + + """ + +xml_tail = \ + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + + +def colormap_fn(x, y, z): + vec = np.array([x, y, z]) + vec = np.clip(vec, 0.001, 1.0) + norm = np.sqrt(np.sum(vec ** 2)) + vec /= norm + return [vec[0], vec[1], vec[2]] + + +color_dict = {'r': [163, 102, 96], 'p': [133,111,139], 'g': [20, 130, 3], + 'o': [145, 128, 47], 'b': [91, 102, 112]} + +color_map = {'airplane': 'r', 'chair': 'o', 'car': 'b', 'table': 'p'} +fov_map = {'airplane': 12, 'chair': 15, 'car':12, 'table':12} +radius_map = {'airplane': 0.0175, 'chair': 0.035, 'car': 0.025, 'table': 0.02} + +def write_to_xml_batch(dir, pcl_batch, color_batch=None, cat='airplane', elev=15, azim=45, radius=np.sqrt(18)): + elev_rad = elev * np.pi / 180 + azim_rad = azim * np.pi / 180 + + x = radius * np.cos(elev_rad)*np.cos(azim_rad) + y = radius * np.cos(elev_rad)*np.sin(azim_rad) + z = radius * np.sin(elev_rad) + + default_color = color_map[cat] + Path(dir).mkdir(parents=True, exist_ok=True) + for k, pcl in enumerate(pcl_batch): + xml_segments = [xml_head.format(x,y,z)] + pcl = standardize_bbox(pcl, pcl.shape[0]) + pcl = pcl[:, [2, 0, 1]] + pcl[:, 0] *= -1 + pcl[:, 2] += 0.0125 + for i in range(pcl.shape[0]): + if color_batch is not None: + color = color_batch[k, i] + else: + color = np.array(color_dict[default_color]) / 255 + # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125) + xml_segments.append(xml_ball_segment.format(0.0175, pcl[i, 0], pcl[i, 1], pcl[i, 2], *color)) + xml_segments.append( + xml_tail.format(pcl[:, 2].min())) + + xml_content = str.join('', xml_segments) + + with open(os.path.join(dir, 'sample_{}.xml'.format(k)), 'w') as f: + f.write(xml_content) + +def write_to_xml(file, pcl, cat='airplane', elev=15, azim=45, radius=np.sqrt(18)): + assert pcl.ndim == 2 + elev_rad = elev * np.pi / 180 + azim_rad = azim * np.pi / 180 + + x = radius * np.cos(elev_rad)*np.cos(azim_rad) + y = radius * np.cos(elev_rad)*np.sin(azim_rad) + z = radius * np.sin(elev_rad) + + default_color = color_map[cat] + + xml_segments = [xml_head.format(x,y,z)] + pcl = standardize_bbox(pcl, pcl.shape[0]) + pcl = pcl[:, [2, 0, 1]] + pcl[:, 0] *= -1 + pcl[:, 2] += 0.0125 + for i in range(pcl.shape[0]): + color = np.array(color_dict[default_color]) / 255 + # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125) + xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color)) + xml_segments.append( + xml_tail.format(pcl[:, 2].min())) + + xml_content = str.join('', xml_segments) + + with open(file, 'w') as f: + f.write(xml_content) diff --git a/utils/visualize.py b/utils/visualize.py new file mode 100644 index 0000000..8153f8a --- /dev/null +++ b/utils/visualize.py @@ -0,0 +1,222 @@ +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import numpy as np +import os +import trimesh +from pathlib import Path + +''' +Custom visualization +''' + +def export_to_pc_batch(dir, pcs, colors=None): + + Path(dir).mkdir(parents=True, exist_ok=True) + for i, xyz in enumerate(pcs): + if colors is None: + color = None + else: + color = colors[i] + pcwrite(os.path.join(dir, 'sample_'+str(i)+'.ply'), 
xyz, color) + + +def export_to_obj(dir, meshes, transform=lambda v,f:(v,f)): + ''' + transform: f(vertices, faces) --> transformed (vertices, faces) + ''' + Path(dir).mkdir(parents=True, exist_ok=True) + for i, data in enumerate(meshes): + v, f = transform(data[0], data[1]) + if len(data) > 2: + v_color = data[2] + else: + v_color = None + mesh = trimesh.Trimesh(v, f, vertex_colors=v_color) + out = trimesh.exchange.obj.export_obj(mesh) + with open(os.path.join(dir, 'sample_'+str(i)+'.obj'), 'w') as f: + f.write(out) + f.close() + +def export_to_obj_single(path, data, transform=lambda v,f:(v,f)): + ''' + transform: f(vertices, faces) --> transformed (vertices, faces) + ''' + v, f = transform(data[0], data[1]) + if len(data) > 2: + v_color = data[2] + else: + v_color = None + mesh = trimesh.Trimesh(v, f, vertex_colors=v_color) + out = trimesh.exchange.obj.export_obj(mesh) + with open(path, 'w') as f: + f.write(out) + f.close() + +def meshwrite(filename, verts, faces, norms, colors): + """Save a 3D mesh to a polygon .ply file. + """ + # Write header + ply_file = open(filename, 'w') + ply_file.write("ply\n") + ply_file.write("format ascii 1.0\n") + ply_file.write("element vertex %d\n" % (verts.shape[0])) + ply_file.write("property float x\n") + ply_file.write("property float y\n") + ply_file.write("property float z\n") + ply_file.write("property float nx\n") + ply_file.write("property float ny\n") + ply_file.write("property float nz\n") + ply_file.write("property uchar red\n") + ply_file.write("property uchar green\n") + ply_file.write("property uchar blue\n") + ply_file.write("element face %d\n" % (faces.shape[0])) + ply_file.write("property list uchar int vertex_index\n") + ply_file.write("end_header\n") + + # Write vertex list + for i in range(verts.shape[0]): + ply_file.write("%f %f %f %f %f %f %d %d %d\n" % ( + verts[i, 0], verts[i, 1], verts[i, 2], + norms[i, 0], norms[i, 1], norms[i, 2], + colors[i, 0], colors[i, 1], colors[i, 2], + )) + + # Write face list + for i in range(faces.shape[0]): + ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2])) + + ply_file.close() + + +def pcwrite(filename, xyz, rgb=None): + """Save a point cloud to a polygon .ply file. + """ + if rgb is None: + rgb = np.ones_like(xyz) * 128 + rgb = rgb.astype(np.uint8) + + # Write header + ply_file = open(filename, 'w') + ply_file.write("ply\n") + ply_file.write("format ascii 1.0\n") + ply_file.write("element vertex %d\n" % (xyz.shape[0])) + ply_file.write("property float x\n") + ply_file.write("property float y\n") + ply_file.write("property float z\n") + ply_file.write("property uchar red\n") + ply_file.write("property uchar green\n") + ply_file.write("property uchar blue\n") + ply_file.write("end_header\n") + + # Write vertex list + for i in range(xyz.shape[0]): + ply_file.write("%f %f %f %d %d %d\n" % ( + xyz[i, 0], xyz[i, 1], xyz[i, 2], + rgb[i, 0], rgb[i, 1], rgb[i, 2], + )) + +''' +Matplotlib Visualization +''' + +def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5): + r''' Visualizes voxel data. 
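+    Expects voxels of shape (batch, 1, res, res, res); cells above `threshold` are drawn.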
+ show only first num_shown + ''' + batch_size =voxels.shape[0] + voxels = voxels.squeeze(1) > threshold + + num_shown = min(num_shown, batch_size) + + n = int(np.sqrt(num_shown)) + fig = plt.figure(figsize=(20,20)) + + for idx, pc in enumerate(voxels[:num_shown]): + if idx >= n*n: + break + pc = voxels[idx] + ax = fig.add_subplot(n, n, idx + 1, projection='3d') + ax.voxels(pc, edgecolor='k', facecolors='green', linewidth=0.1, alpha=0.5) + ax.view_init() + ax.axis('off') + plt.savefig(out_file, bbox_inches='tight') + plt.close() + +def visualize_pointcloud(points, normals=None, + out_file=None, show=False, elev=30, azim=225): + r''' Visualizes point cloud data. + Args: + points (tensor): point data + normals (tensor): normal data (if existing) + out_file (string): output file + show (bool): whether the plot should be shown + ''' + # Create plot + fig = plt.figure() + ax = fig.gca(projection=Axes3D.name) + ax.scatter(points[:, 2], points[:, 0], points[:, 1]) + if normals is not None: + ax.quiver( + points[:, 2], points[:, 0], points[:, 1], + normals[:, 2], normals[:, 0], normals[:, 1], + length=0.1, color='k' + ) + ax.set_xlabel('Z') + ax.set_ylabel('X') + ax.set_zlabel('Y') + # ax.set_xlim(-0.5, 0.5) + # ax.set_ylim(-0.5, 0.5) + # ax.set_zlim(-0.5, 0.5) + ax.view_init(elev=elev, azim=azim) + if out_file is not None: + plt.savefig(out_file) + if show: + plt.show() + plt.close(fig) + + +def visualize_pointcloud_batch(path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225): + batch_size = len(pointclouds) + fig = plt.figure(figsize=(20,20)) + + ncols = int(np.sqrt(batch_size)) + nrows = max(1, (batch_size-1) // ncols+1) + for idx, pc in enumerate(pointclouds): + if vis_label: + label = categories[labels[idx].item()] + pred = categories[pred_labels[idx]] + colour = 'g' if label == pred else 'r' + elif target is None: + + colour = 'g' + else: + colour = target[idx] + pc = pc.cpu().numpy() + ax = fig.add_subplot(nrows, ncols, idx + 1, projection='3d') + ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5) + ax.view_init(elev=elev, azim=azim) + ax.axis('off') + if vis_label: + ax.set_title('GT: {0}\nPred: {1}'.format(label, pred)) + + plt.savefig(path) + plt.close(fig) + + +''' +Plot stats +''' + +def plot_stats(output_dir, stats, interval): + content = stats.keys() + # f = plt.figure(figsize=(20, len(content) * 5)) + f, axs = plt.subplots(len(content), 1, figsize=(20, len(content) * 5)) + for j, (k, v) in enumerate(stats.items()): + axs[j].plot(interval, v) + axs[j].set_ylabel(k) + + f.savefig(os.path.join(output_dir, 'stat.pdf'), bbox_inches='tight') + plt.close(f) diff --git a/utils/xml_from_mesh.py b/utils/xml_from_mesh.py new file mode 100644 index 0000000..b833148 --- /dev/null +++ b/utils/xml_from_mesh.py @@ -0,0 +1,86 @@ +import sys + +sys.path.append('..') +import argparse +import os +import numpy as np +import trimesh +import glob +from joblib import Parallel, delayed +import re +from utils.mitsuba_renderer import write_to_xml_batch +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. 
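+    Implemented with the Euler-Rodrigues formula.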
+    """
+    axis = np.asarray(axis)
+    axis = axis / np.sqrt(np.dot(axis, axis))
+    a = np.cos(theta / 2.0)
+    b, c, d = -axis * np.sin(theta / 2.0)
+    aa, bb, cc, dd = a * a, b * b, c * c, d * d
+    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
+    return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
+                     [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
+                     [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
+
+def rotate(vertices, faces):
+    '''
+    vertices: [numpoints, 3]
+    '''
+    N = rotation_matrix([1, 0, 0], 3 * np.pi / 4).transpose()
+    # M = rotation_matrix([0, 1, 0], -np.pi / 2).transpose()
+
+    v, f = vertices.dot(N), faces
+    return v, f
+
+def as_mesh(scene_or_mesh):
+    if isinstance(scene_or_mesh, trimesh.Scene):
+        mesh = trimesh.util.concatenate([
+            trimesh.Trimesh(vertices=m.vertices, faces=m.faces)
+            for m in scene_or_mesh.geometry.values()])
+    else:
+        mesh = scene_or_mesh
+    return mesh
+
+def process_one(shape_dir, cat):
+    pc_paths = glob.glob(os.path.join(shape_dir, "*.obj"))
+    pc_paths = sorted(pc_paths)
+
+    xml_paths = []  # [re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths]
+
+    gen_pcs = []
+    for path in pc_paths:
+        sample_mesh = trimesh.load(path, force='mesh')
+        v, f = rotate(sample_mesh.vertices, sample_mesh.faces)
+        mesh = trimesh.Trimesh(v, f)
+        sample_pts = trimesh.sample.sample_surface(mesh, 2048)[0]
+        gen_pcs.append(sample_pts)
+        xml_paths.append(re.sub('.obj', '.xml', os.path.basename(path)))
+
+    gen_pcs = np.stack(gen_pcs, axis=0)
+    write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths, cat=cat)
+
+
+def process(args):
+    shape_names = [n for n in sorted(os.listdir(args.src)) if
+                   os.path.isdir(os.path.join(args.src, n)) and not n.startswith('x')]
+
+    all_shape_dir = [os.path.join(args.src, name) for name in shape_names]
+
+    Parallel(n_jobs=10, verbose=2)(delayed(process_one)(path, args.cat) for path in all_shape_dir)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--src", type=str)
+    parser.add_argument("--cat", type=str)
+    args = parser.parse_args()
+
+    process_one(args.src, args.cat)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/utils/xml_from_ply.py b/utils/xml_from_ply.py
new file mode 100644
index 0000000..6a0aff1
--- /dev/null
+++ b/utils/xml_from_ply.py
@@ -0,0 +1,54 @@
+import sys
+
+sys.path.append('..')
+import argparse
+import os
+import numpy as np
+import trimesh
+import glob
+from joblib import Parallel, delayed
+import re
+from utils.mitsuba_renderer import write_to_xml_batch
+
+
+def process_one(shape_dir):
+    pc_paths = glob.glob(os.path.join(shape_dir, "fake*.ply"))
+    pc_paths = sorted(pc_paths)
+
+    xml_paths = [re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths]
+
+    gen_pcs = []
+    for path in pc_paths:
+        sample_pts = trimesh.load(path)
+        sample_pts = np.array(sample_pts.vertices)
+        gen_pcs.append(sample_pts)
+
+    raw_pc = np.array(trimesh.load(os.path.join(shape_dir, "raw.ply")).vertices)
+    raw_pc = np.concatenate([raw_pc, np.tile(raw_pc[0:1], (gen_pcs[0].shape[0]-raw_pc.shape[0], 1))])
+
+    gen_pcs.append(raw_pc)
+    gen_pcs = np.stack(gen_pcs, axis=0)
+    xml_paths.append('raw.xml')
+
+    write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths)
+
+
+def process(args):
+    shape_names = [n for n in sorted(os.listdir(args.src)) if
+                   os.path.isdir(os.path.join(args.src, n)) and not n.startswith('x')]
+
+    all_shape_dir = [os.path.join(args.src, name) for name in shape_names]
+
+    Parallel(n_jobs=10, verbose=2)(delayed(process_one)(path) for path in all_shape_dir)
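+# NOTE: main() renders one directory directly via process_one(); process() above is
+# the joblib fan-out over a tree of per-shape subdirectories and must be invoked manually.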
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--src", type=str)
+    args = parser.parse_args()
+
+    process_one(args.src)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file