diff --git a/README.md b/README.md
index a7bbcfe..b8bc816 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
 # Shape Generation and Completion Through Point-Voxel Diffusion
+
+<!-- teaser animation: assets/pvd_teaser.gif -->
+
 
-[Project]() | [Paper]()
+[Project](https://alexzhou907.github.io/pvd) | [Paper](https://arxiv.org/abs/2104.03670)
 
-Implementation of
+Implementation of Shape Generation and Completion Through Point-Voxel Diffusion
 
-## Pretrained Models
-
-Pretrained models can be accessed [here](https://www.dropbox.com/s/a3xydf594fzaokl/cifar10_pretrained.rar?dl=0).
+[Linqi Zhou](https://alexzhou907.github.io), [Yilun Du](https://yilundu.github.io/), [Jiajun Wu](https://jiajunwu.com/)
 
 ## Requirements:
@@ -20,24 +21,56 @@
 cudatoolkit==10.1
 matplotlib==2.2.5
 tqdm==4.32.1
 open3d==0.9.0
+trimesh==3.7.12
+scipy==1.5.1
 ```
+
+Install PyTorchEMD by:
+```
+cd metrics/PyTorchEMD
+python setup.py install
+cp build/**/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
+```
+A minimal sanity check for the build is sketched at the end of this README.
+
 The code was tested on Ubuntu with a Titan RTX.
 
+## Data
 
-## Training on CIFAR-10:
+For generation, we use the ShapeNet point cloud dataset, which can be downloaded [here](https://github.com/stevenygd/PointFlow).
+
+For completion, we use the ShapeNet renderings provided by [GenRe](https://github.com/xiumingzhang/GenRe-ShapeHD).
+We provide the script `convert_cam_params.py` to process the provided data.
+
+For training the model on shape completion, we need camera parameters for each view,
+which are not directly available. To obtain these, simply run
+```bash
+$ python convert_cam_params.py --dataroot DATA_DIR --mitsuba_xml_root XML_DIR
+```
+which will create `..._cam_params.npz` in each provided data folder for each view.
+
+## Pretrained models
+
+Pretrained models can be downloaded [here](https://drive.google.com/drive/folders/1Q7aSaTr6lqmo8qx80nIm1j28mOHAHGiM?usp=sharing).
+
+## Training:
 
 ```bash
-$ python train_cifar.py
+$ python train_generation.py --category car|chair|airplane
 ```
 
 Please refer to the Python file for optimal training parameters.
 
+## Testing:
+
+```bash
+$ python train_generation.py --category car|chair|airplane --model MODEL_PATH
+```
+
 ## Results
 
 Some generative results are as follows.

-<!-- old result images -->
+<!-- result animations: see assets/*.gif -->

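+## Sanity check for PyTorchEMD
+
+A minimal sketch for verifying the EMD build, assuming the extension is
+importable as `metrics.PyTorchEMD.emd` and exposes `earth_mover_distance`
+(adjust the import if your binding differs); the op requires a CUDA device:
+
+```python
+import torch
+from metrics.PyTorchEMD.emd import earth_mover_distance
+
+# two batches of random point clouds, shape (batch, n_points, 3)
+a = torch.rand(4, 1024, 3).cuda()
+b = torch.rand(4, 1024, 3).cuda()
+
+# transpose=False means inputs are (batch, n_points, 3), not (batch, 3, n_points)
+d = earth_mover_distance(a, b, transpose=False)
+print(d.shape)  # one EMD value per batch element: torch.Size([4])
+```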
diff --git a/assets/gen_comp.gif b/assets/gen_comp.gif
new file mode 100644
index 0000000..a169382
Binary files /dev/null and b/assets/gen_comp.gif differ
diff --git a/assets/mm_partnet.gif b/assets/mm_partnet.gif
new file mode 100644
index 0000000..28defd9
Binary files /dev/null and b/assets/mm_partnet.gif differ
diff --git a/assets/mm_redwood.gif b/assets/mm_redwood.gif
new file mode 100644
index 0000000..cb786e0
Binary files /dev/null and b/assets/mm_redwood.gif differ
diff --git a/assets/mm_shapenet.gif b/assets/mm_shapenet.gif
new file mode 100644
index 0000000..a9231d2
Binary files /dev/null and b/assets/mm_shapenet.gif differ
diff --git a/assets/pvd_teaser.gif b/assets/pvd_teaser.gif
new file mode 100644
index 0000000..37c03ad
Binary files /dev/null and b/assets/pvd_teaser.gif differ
diff --git a/convert_cam_params.py b/convert_cam_params.py
new file mode 100644
index 0000000..36abb2b
--- /dev/null
+++ b/convert_cam_params.py
@@ -0,0 +1,123 @@
+from glob import glob
+import re
+import argparse
+import numpy as np
+from pathlib import Path
+import os
+
+
+def raw_camparam_from_xml(path, pose="lookAt"):
+    """Read the raw look-at camera parameters from a Mitsuba XML scene file."""
+    import xml.etree.ElementTree as ET
+    tree = ET.parse(path)
+    elm = tree.find("./sensor/transform/" + pose)
+    camparam = elm.attrib
+    origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',')
+    target = np.fromstring(camparam['target'], dtype=np.float32, sep=',')
+    up = np.fromstring(camparam['up'], dtype=np.float32, sep=',')
+    height = int(
+        tree.find("./sensor/film/integer[@name='height']").attrib['value'])
+    width = int(
+        tree.find("./sensor/film/integer[@name='width']").attrib['value'])
+
+    camparam = dict()
+    camparam['origin'] = origin
+    camparam['up'] = up
+    camparam['target'] = target
+    camparam['height'] = height
+    camparam['width'] = width
+    return camparam
+
+
+def get_cam_pos(origin, target, up):
+    """Build the 4x4 world-to-camera extrinsic matrix from look-at parameters."""
+    # orthonormal camera frame: right (rx), up (ry), inward (rz)
+    inward = origin - target
+    right = np.cross(up, inward)
+    up = np.cross(inward, right)
+    rx = np.cross(up, inward)
+    ry = np.array(up)
+    rz = np.array(inward)
+    rx /= np.linalg.norm(rx)
+    ry /= np.linalg.norm(ry)
+    rz /= np.linalg.norm(rz)
+
+    # rotation rows: x right, y up, z forward (the inward axis negated)
+    rot = np.stack([rx, ry, -rz], axis=0)
+
+    # translate the world so the camera sits at the origin, then rotate
+    aff = np.concatenate([np.eye(3), -origin[:, None]], axis=1)
+    ext = np.matmul(rot, aff)
+
+    result = np.concatenate([ext, np.array([[0, 0, 0, 1]])], axis=0)
+    return result
+
+
+def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
+    # match each rendered depth map with its Mitsuba XML under the xml root
+    depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png')))
+    cam_ext = ['_'.join(re.sub(dataroot.strip('/'), camera_param_dir.strip('/'),
+                               f).split('_')[:-1]) + '.xml' for f in depths]
+
+    for i, (f, pth) in enumerate(zip(cam_ext, depths)):
+        if not os.path.exists(f):
+            continue
+        params = raw_camparam_from_xml(f)
+        origin, target, up, width, height = params['origin'], params['target'], \
+            params['up'], params['width'], params['height']
+
+        ext_matrix = get_cam_pos(origin, target, up)
+
+        # intrinsics: square sensor sharing the 36x24 mm full-frame diagonal,
+        # 50 mm focal length, 480x480 rendering; the negative y focal flips
+        # the image axis
+        diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5
+        focal_length = 0.05
+        res = [480, 480]
+        h_relative = (res[1] / res[0])
+        sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2))
+        pix_size = sensor_width / res[0]
+
+        K = np.array([
+            [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
+            [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
+            [0, 0, 1]
+        ])
+
+        np.savez(pth.split('depth.png')[0] + 'cam_params.npz', extr=ext_matrix, intr=K)
+
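+
+# Example of consuming the saved output (the .npz keys match the np.savez call
+# above; the path below is hypothetical and depends on each depth map's prefix):
+#
+#     params = np.load('GenReData/<model>/<view>_cam_params.npz')
+#     extr, intr = params['extr'], params['intr']   # 4x4 extrinsic, 3x3 intrinsic
+#     X = np.array([0.1, 0.2, 0.3, 1.0])            # homogeneous world point
+#     cam = (extr @ X)[:3]                          # world -> camera coordinates
+#     u, v, w = intr @ cam                          # homogeneous pixel coordinates
+#     px, py = u / w, v / w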
+
+
+def main(opt):
+    dataroot_dir = Path(opt.dataroot)
+
+    leaf_subdirs = []
+
+    for dirpath, dirnames, filenames in os.walk(dataroot_dir):
+        if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
+            leaf_subdirs.append(dirpath)
+
+    for k, dir_ in enumerate(leaf_subdirs):
+        print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_))
+
+        convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)
+
+
+if __name__ == '__main__':
+    args = argparse.ArgumentParser()
+    args.add_argument('--dataroot', type=str, default='GenReData/')
+    args.add_argument('--mitsuba_xml_root', type=str, default='GenReData/genre-xml_v2')
+
+    opt = args.parse_args()
+
+    main(opt)
diff --git a/datasets/partnet.py b/datasets/partnet.py
index 72417a0..f074663 100644
--- a/datasets/partnet.py
+++ b/datasets/partnet.py
@@ -5,9 +5,7 @@ import os
 import json
 import random
 import trimesh
-import csv
 from plyfile import PlyData, PlyElement
-from glob import glob
 
 def project_pc_to_image(points, resolution=64):
     """project point clouds into 2D image
@@ -181,33 +179,3 @@ class GANdatasetPartNet(Dataset):
 
-
-if __name__ == '__main__':
-    data_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetPointCloud'
-    data_raw_root = '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc'
-    pc_dataroot = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k'
-
-    sn_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2'
-    classes = 'car'
-    npoints = 2048
-    # from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-    # pc_ds = ShapeNet15kPointClouds(root_dir=pc_dataroot,
-    #                                categories=[classes], split='train',
-    #                                tr_sample_size=npoints,
-    #                                te_sample_size=npoints,
-    #                                scale=1.,
-    #                                normalize_per_shape=False,
-    #                                normalize_std_per_axis=False,
-    #                                random_subsample=True)
-
-    train_ds = GANdatasetPartNet('test', pc_dataroot, data_raw_root, classes, npoints, np.array([0,0,0]),
-                                 np.array([1, 1, 1]))
-
-    d1 = train_ds[0]
-    real = d1['real']
-    raw = d1['raw']
-    m, s = d1['m'], d1['s']
-    x = (torch.cat([raw, real], dim=-1) * s + m).transpose(0,1)
-
-    write_ply(x.numpy(), 'x.ply')
-    pass
diff --git a/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py
deleted file mode 100644
index cd76baf..0000000
--- a/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py
+++ /dev/null
@@ -1,825 +0,0 @@
-import torch.multiprocessing as mp
-import torch.nn as nn
-import torch.optim as optim
-import torch.utils.data
-
-import argparse
-from model.unet import get_model
-from torch.distributions import Normal
-
-from utils.file_utils import *
-from utils.visualize import *
-from model.pvcnn_completion import PVCNN2Base
-import torch.distributed as dist
-from datasets.partnet import GANdatasetPartNet
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-'''
-some utils
-'''
-def rotation_matrix(axis, theta):
-    """
-    Return the rotation matrix associated with counterclockwise rotation about
-    the given axis by theta radians.
- """ - axis = np.asarray(axis) - axis = axis / np.sqrt(np.dot(axis, axis)) - a = np.cos(theta / 2.0) - b, c, d = -axis * np.sin(theta / 2.0) - aa, bb, cc, dd = a * a, b * b, c * c, d * d - bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d - return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], - [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], - [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) - -def rotate(vertices, faces): - ''' - vertices: [numpoints, 3] - ''' - M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() - N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() - K = rotation_matrix([0, 0, 1], np.pi).transpose() - - v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] - return v, f - -def norm(v, f): - v = (v - v.min())/(v.max() - v.min()) - 0.5 - - return v, f - -def getGradNorm(net): - pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) - gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) - return pNorm, gradNorm - - -def weights_init(m): - """ - xavier initialization - """ - classname = m.__class__.__name__ - if classname.find('Conv') != -1 and m.weight is not None: - torch.nn.init.xavier_normal_(m.weight) - elif classname.find('BatchNorm') != -1: - m.weight.data.normal_() - m.bias.data.fill_(0) - -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. 
- betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = 
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate images - Useful for visualizing how denoised images evolve over time - Args: - repeat_noise_steps (int): Number of denoising timesteps in which the same noise - is used across the batch. If >= 0, the initial noise is the same for all batch elemements. - """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) - - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. 
seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) - - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - -def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category): - - train_ds = GANdatasetPartNet('train', data_root, category, npoints) - return train_ds - - -def get_dataloader(opt, train_dataset, test_dataset=None): - - if opt.distribution_type == 'multi': - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, - num_replicas=opt.world_size, - rank=opt.rank - ) - if test_dataset is not None: - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_dataset, - num_replicas=opt.world_size, - rank=opt.rank - ) - else: - test_sampler = None - else: - train_sampler = None - test_sampler = None - - train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler, - shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True) - - if test_dataset is not None: - test_dataloader = 
torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - else: - test_dataloader = None - - return train_dataloader, test_dataloader, train_sampler, test_sampler - - -def train(gpu, opt, output_dir, noises_init): - - set_seed(opt) - logger = setup_logging(output_dir) - if opt.distribution_type == 'multi': - should_diag = gpu==0 - else: - should_diag = True - if should_diag: - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - if opt.distribution_type == 'multi': - if opt.dist_url == "env://" and opt.rank == -1: - opt.rank = int(os.environ["RANK"]) - - base_rank = opt.rank * opt.ngpus_per_node - opt.rank = base_rank + gpu - dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, - world_size=opt.world_size, rank=opt.rank) - - opt.bs = int(opt.bs / opt.ngpus_per_node) - opt.workers = 0 - - opt.saveIter = int(opt.saveIter / opt.ngpus_per_node) - opt.diagIter = int(opt.diagIter / opt.ngpus_per_node) - opt.vizIter = int(opt.vizIter / opt.ngpus_per_node) - - - ''' data ''' - train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes) - dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) - - - ''' - create networks - ''' - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.distribution_type == 'multi': # Multiple processes, single GPU per process - def _transform_(m): - return nn.parallel.DistributedDataParallel( - m, device_ids=[gpu], output_device=gpu) - - torch.cuda.set_device(gpu) - netE.cuda(gpu) - netE.multi_gpu_wrapper(_transform_) - - - elif opt.distribution_type == 'single': - def _transform_(m): - return nn.parallel.DataParallel(m) - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - elif gpu is not None: - torch.cuda.set_device(gpu) - netE = netE.cuda(gpu) - else: - raise ValueError('distribution_type = multi | single | None') - - if should_diag: - logger.info(opt) - - optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999)) - - lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma) - - if opt.netE != '': - ckpt = torch.load(opt.netE) - netE.load_state_dict(ckpt['model_state']) - optimizer.load_state_dict(ckpt['optimizer_state']) - - if opt.netE != '': - start_epoch = torch.load(opt.netE)['epoch'] + 1 - else: - start_epoch = 0 - - - for epoch in range(start_epoch, opt.niter): - - if opt.distribution_type == 'multi': - train_sampler.set_epoch(epoch) - - lr_scheduler.step(epoch) - - for i, data in enumerate(dataloader): - - x = data['real'] - sv_x = data['raw'] - - sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1) - noises_batch = noises_init[data['idx']] - - ''' - train diffusion - ''' - - if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None): - sv_x = sv_x.cuda(gpu) - noises_batch = noises_batch.cuda(gpu) - elif opt.distribution_type == 'single': - sv_x = sv_x.cuda() - noises_batch = noises_batch.cuda() - - loss = netE.get_loss_iter(sv_x, noises_batch).mean() - - optimizer.zero_grad() - loss.backward() - netpNorm, netgradNorm = getGradNorm(netE) - if opt.grad_clip is not None: - torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip) - - optimizer.step() - - - if i % opt.print_freq == 0 and should_diag: - - logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: 
{:>10.4f}, ' - 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' - .format( - epoch, opt.niter, i, len(dataloader),loss.item(), - netpNorm, netgradNorm, - )) - - if (epoch + 1) % opt.diagIter == 0 and should_diag: - - logger.info('Diagnosis:') - - x_range = [x.min().item(), x.max().item()] - kl_stats = netE.all_kl(sv_x) - logger.info(' [{:>3d}/{:>3d}] ' - 'x_range: [{:>10.4f}, {:>10.4f}], ' - 'total_bpd_b: {:>10.4f}, ' - 'terms_bpd: {:>10.4f}, ' - 'prior_bpd_b: {:>10.4f} ' - 'mse_bt: {:>10.4f} ' - .format( - epoch, opt.niter, - *x_range, - kl_stats['total_bpd_b'].item(), - kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item() - )) - - - - if (epoch + 1) % opt.vizIter == 0 and should_diag: - logger.info('Generation: eval') - - netE.eval() - m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0) - - with torch.no_grad(): - - x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu() - - - gen_stats = [x_gen_eval.mean(), x_gen_eval.std()] - gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()] - - logger.info(' [{:>3d}/{:>3d}] ' - 'eval_gen_range: [{:>10.4f}, {:>10.4f}] ' - 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ' - .format( - epoch, opt.niter, - *gen_eval_range, *gen_stats, - )) - - export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch), - (x_gen_eval*s+m).transpose(1, 2).numpy()*3) - - export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch), - (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3) - - export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch), - (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3) - - - netE.train() - - - - - - - - if (epoch + 1) % opt.saveIter == 0: - - if should_diag: - - - save_dict = { - 'epoch': epoch, - 'model_state': netE.state_dict(), - 'optimizer_state': optimizer.state_dict() - } - - torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch)) - - - if opt.distribution_type == 'multi': - dist.barrier() - map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu} - netE.load_state_dict( - torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state']) - - dist.destroy_process_group() - -def main(): - opt = parse_args() - - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - - ''' workaround ''' - - train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes) - noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints) - - if opt.dist_url == "env://" and opt.world_size == -1: - opt.world_size = int(os.environ["WORLD_SIZE"]) - - if opt.distribution_type == 'multi': - opt.ngpus_per_node = torch.cuda.device_count() - opt.world_size = opt.ngpus_per_node * opt.world_size - mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init)) - else: - train(opt.gpu, opt, output_dir, noises_init) - - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/data_v0', help='input batch size') - parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc', - help='input batch size') - parser.add_argument('--pc_dataroot', 
default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', - help='input batch size') - parser.add_argument('--classes', default='Chair') - - parser.add_argument('--bs', type=int, default=64, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=1024) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002') - parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') - parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM') - parser.add_argument('--grad_clip', type=float, default=None, help='weight decay for EBM') - parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM') - - parser.add_argument('--netE', default='', help="path to netE (to continue training)") - - - '''distributed''' - parser.add_argument('--world_size', default=1, type=int, - help='Number of distributed nodes.') - parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str, - help='url used to set up distributed training') - parser.add_argument('--dist_backend', default='nccl', type=str, - help='distributed backend') - parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None], - help='Use multi-processing distributed training to launch ' - 'N processes per node, which has N GPUs. This is the ' - 'fastest way to use PyTorch for either single node or ' - 'multi node data parallel training') - parser.add_argument('--rank', default=0, type=int, - help='node rank for distributed training') - parser.add_argument('--gpu', default=None, type=int, - help='GPU id to use. 
None means using all available GPUs.') - - '''eval''' - parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch') - parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch') - parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch') - parser.add_argument('--print_freq', default=50, type=int,help='unit: iter') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - - opt = parser.parse_args() - - return opt - -if __name__ == '__main__': - main() diff --git a/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py deleted file mode 100644 index f7a39fa..0000000 --- a/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py +++ /dev/null @@ -1,825 +0,0 @@ -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -import torch.utils.data - -import argparse -from model.unet import get_model -from torch.distributions import Normal - -from utils.file_utils import * -from utils.visualize import * -from model.pvcnn_completion import PVCNN2Base -import torch.distributed as dist -from datasets.partnet import GANdatasetPartNet -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -''' -some utils -''' -def rotation_matrix(axis, theta): - """ - Return the rotation matrix associated with counterclockwise rotation about - the given axis by theta radians. - """ - axis = np.asarray(axis) - axis = axis / np.sqrt(np.dot(axis, axis)) - a = np.cos(theta / 2.0) - b, c, d = -axis * np.sin(theta / 2.0) - aa, bb, cc, dd = a * a, b * b, c * c, d * d - bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d - return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], - [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], - [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) - -def rotate(vertices, faces): - ''' - vertices: [numpoints, 3] - ''' - M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() - N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() - K = rotation_matrix([0, 0, 1], np.pi).transpose() - - v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] - return v, f - -def norm(v, f): - v = (v - v.min())/(v.max() - v.min()) - 0.5 - - return v, f - -def getGradNorm(net): - pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) - gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) - return pNorm, gradNorm - - -def weights_init(m): - """ - xavier initialization - """ - classname = m.__class__.__name__ - if classname.find('Conv') != -1 and m.weight is not None: - torch.nn.init.xavier_normal_(m.weight) - elif classname.find('BatchNorm') != -1: - m.weight.data.normal_() - m.bias.data.fill_(0) - -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. 
- """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
- """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, 
x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate images - Useful for visualizing how denoised images evolve over time - Args: - repeat_noise_steps (int): Number of denoising timesteps in which the same noise - is used across the batch. If >= 0, the initial noise is the same for all batch elemements. - """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) 
- - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) - - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 
2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - -def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category): - - train_ds = GANdatasetPartNet('train', data_root, category, npoints) - return train_ds - - -def get_dataloader(opt, train_dataset, test_dataset=None): - - if opt.distribution_type == 'multi': - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, - num_replicas=opt.world_size, - rank=opt.rank - ) - if 
test_dataset is not None: - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_dataset, - num_replicas=opt.world_size, - rank=opt.rank - ) - else: - test_sampler = None - else: - train_sampler = None - test_sampler = None - - train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler, - shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True) - - if test_dataset is not None: - test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - else: - test_dataloader = None - - return train_dataloader, test_dataloader, train_sampler, test_sampler - - -def train(gpu, opt, output_dir, noises_init): - - set_seed(opt) - logger = setup_logging(output_dir) - if opt.distribution_type == 'multi': - should_diag = gpu==0 - else: - should_diag = True - if should_diag: - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - if opt.distribution_type == 'multi': - if opt.dist_url == "env://" and opt.rank == -1: - opt.rank = int(os.environ["RANK"]) - - base_rank = opt.rank * opt.ngpus_per_node - opt.rank = base_rank + gpu - dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, - world_size=opt.world_size, rank=opt.rank) - - opt.bs = int(opt.bs / opt.ngpus_per_node) - opt.workers = 0 - - opt.saveIter = int(opt.saveIter / opt.ngpus_per_node) - opt.diagIter = int(opt.diagIter / opt.ngpus_per_node) - opt.vizIter = int(opt.vizIter / opt.ngpus_per_node) - - - ''' data ''' - train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes) - dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) - - - ''' - create networks - ''' - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.distribution_type == 'multi': # Multiple processes, single GPU per process - def _transform_(m): - return nn.parallel.DistributedDataParallel( - m, device_ids=[gpu], output_device=gpu) - - torch.cuda.set_device(gpu) - netE.cuda(gpu) - netE.multi_gpu_wrapper(_transform_) - - - elif opt.distribution_type == 'single': - def _transform_(m): - return nn.parallel.DataParallel(m) - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - elif gpu is not None: - torch.cuda.set_device(gpu) - netE = netE.cuda(gpu) - else: - raise ValueError('distribution_type = multi | single | None') - - if should_diag: - logger.info(opt) - - optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999)) - - lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma) - - if opt.netE != '': - ckpt = torch.load(opt.netE) - netE.load_state_dict(ckpt['model_state']) - optimizer.load_state_dict(ckpt['optimizer_state']) - - if opt.netE != '': - start_epoch = torch.load(opt.netE)['epoch'] + 1 - else: - start_epoch = 0 - - - for epoch in range(start_epoch, opt.niter): - - if opt.distribution_type == 'multi': - train_sampler.set_epoch(epoch) - - lr_scheduler.step(epoch) - - for i, data in enumerate(dataloader): - - x = data['real'] - sv_x = data['raw'] - - sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1) - noises_batch = noises_init[data['idx']] - - ''' - train diffusion - ''' - - if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None): - sv_x = sv_x.cuda(gpu) - noises_batch = 
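One caveat in `get_dataloader` above: the test branch wraps `train_dataset` rather than `test_dataset`. It is dead code in this script, since `train()` always passes `test_dataset=None`, but a corrected sketch of that branch (assuming the surrounding scope, i.e. `opt` and the samplers defined above) would be:

```python
# Corrected test-loader construction; the deleted helper wraps train_dataset here.
if test_dataset is not None:
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,                  # was train_dataset in the original
        batch_size=opt.bs,
        sampler=test_sampler,
        shuffle=False,
        num_workers=int(opt.workers),
        drop_last=False,
    )
else:
    test_dataloader = None
```

The same helper, with the same slip, appears again in the second deleted script below.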
noises_batch.cuda(gpu) - elif opt.distribution_type == 'single': - sv_x = sv_x.cuda() - noises_batch = noises_batch.cuda() - - loss = netE.get_loss_iter(sv_x, noises_batch).mean() - - optimizer.zero_grad() - loss.backward() - netpNorm, netgradNorm = getGradNorm(netE) - if opt.grad_clip is not None: - torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip) - - optimizer.step() - - - if i % opt.print_freq == 0 and should_diag: - - logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, ' - 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' - .format( - epoch, opt.niter, i, len(dataloader),loss.item(), - netpNorm, netgradNorm, - )) - - if (epoch + 1) % opt.diagIter == 0 and should_diag: - - logger.info('Diagnosis:') - - x_range = [x.min().item(), x.max().item()] - kl_stats = netE.all_kl(sv_x) - logger.info(' [{:>3d}/{:>3d}] ' - 'x_range: [{:>10.4f}, {:>10.4f}], ' - 'total_bpd_b: {:>10.4f}, ' - 'terms_bpd: {:>10.4f}, ' - 'prior_bpd_b: {:>10.4f} ' - 'mse_bt: {:>10.4f} ' - .format( - epoch, opt.niter, - *x_range, - kl_stats['total_bpd_b'].item(), - kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item() - )) - - - - if (epoch + 1) % opt.vizIter == 0 and should_diag: - logger.info('Generation: eval') - - netE.eval() - m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0) - - with torch.no_grad(): - - x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu() - - - gen_stats = [x_gen_eval.mean(), x_gen_eval.std()] - gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()] - - logger.info(' [{:>3d}/{:>3d}] ' - 'eval_gen_range: [{:>10.4f}, {:>10.4f}] ' - 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ' - .format( - epoch, opt.niter, - *gen_eval_range, *gen_stats, - )) - - export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch), - (x_gen_eval*s+m).transpose(1, 2).numpy()*3) - - export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch), - (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3) - - export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch), - (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3) - - - netE.train() - - - - - - - - if (epoch + 1) % opt.saveIter == 0: - - if should_diag: - - - save_dict = { - 'epoch': epoch, - 'model_state': netE.state_dict(), - 'optimizer_state': optimizer.state_dict() - } - - torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch)) - - - if opt.distribution_type == 'multi': - dist.barrier() - map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu} - netE.load_state_dict( - torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state']) - - dist.destroy_process_group() - -def main(): - opt = parse_args() - - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - - ''' workaround ''' - - train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes) - noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints) - - if opt.dist_url == "env://" and opt.world_size == -1: - opt.world_size = int(os.environ["WORLD_SIZE"]) - - if opt.distribution_type == 'multi': - opt.ngpus_per_node = torch.cuda.device_count() - opt.world_size = opt.ngpus_per_node * opt.world_size - mp.spawn(train, nprocs=opt.ngpus_per_node, 
args=(opt, output_dir, noises_init)) - else: - train(opt.gpu, opt, output_dir, noises_init) - - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/', help='input batch size') - parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc', - help='input batch size') - parser.add_argument('--pc_dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', - help='input batch size') - parser.add_argument('--classes', default='Table') - - parser.add_argument('--bs', type=int, default=64, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=1024) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002') - parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') - parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM') - parser.add_argument('--grad_clip', type=float, default=None, help='weight decay for EBM') - parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM') - - parser.add_argument('--netE', default='', help="path to netE (to continue training)") - - - '''distributed''' - parser.add_argument('--world_size', default=1, type=int, - help='Number of distributed nodes.') - parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str, - help='url used to set up distributed training') - parser.add_argument('--dist_backend', default='nccl', type=str, - help='distributed backend') - parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None], - help='Use multi-processing distributed training to launch ' - 'N processes per node, which has N GPUs. This is the ' - 'fastest way to use PyTorch for either single node or ' - 'multi node data parallel training') - parser.add_argument('--rank', default=0, type=int, - help='node rank for distributed training') - parser.add_argument('--gpu', default=None, type=int, - help='GPU id to use. 
None means using all available GPUs.') - - '''eval''' - parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch') - parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch') - parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch') - parser.add_argument('--print_freq', default=50, type=int,help='unit: iter') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - - opt = parser.parse_args() - - return opt - -if __name__ == '__main__': - main() diff --git a/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py deleted file mode 100644 index b8a23f5..0000000 --- a/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py +++ /dev/null @@ -1,822 +0,0 @@ -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -import torch.utils.data - -import argparse -from model.unet import get_model -from torch.distributions import Normal - -from utils.file_utils import * -from utils.visualize import * -from model.pvcnn_completion import PVCNN2Base -import torch.distributed as dist -from datasets.partnet import GANdatasetPartNet -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -''' -some utils -''' -def rotation_matrix(axis, theta): - """ - Return the rotation matrix associated with counterclockwise rotation about - the given axis by theta radians. - """ - axis = np.asarray(axis) - axis = axis / np.sqrt(np.dot(axis, axis)) - a = np.cos(theta / 2.0) - b, c, d = -axis * np.sin(theta / 2.0) - aa, bb, cc, dd = a * a, b * b, c * c, d * d - bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d - return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], - [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], - [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) - -def rotate(vertices, faces): - ''' - vertices: [numpoints, 3] - ''' - M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() - N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() - K = rotation_matrix([0, 0, 1], np.pi).transpose() - - v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] - return v, f - -def norm(v, f): - v = (v - v.min())/(v.max() - v.min()) - 0.5 - - return v, f - -def getGradNorm(net): - pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) - gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) - return pNorm, gradNorm - - -def weights_init(m): - """ - xavier initialization - """ - classname = m.__class__.__name__ - if classname.find('Conv') != -1 and m.weight is not None: - torch.nn.init.xavier_normal_(m.weight) - elif classname.find('BatchNorm') != -1: - m.weight.data.normal_() - m.bias.data.fill_(0) - -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. 
- """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
- """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, 
x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate images - Useful for visualizing how denoised images evolve over time - Args: - repeat_noise_steps (int): Number of denoising timesteps in which the same noise - is used across the batch. If >= 0, the initial noise is the same for all batch elemements. - """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) 
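The completion-specific detail in `p_sample` and `p_sample_loop` above is that the first `sv_points` columns are never denoised: each reverse step samples only the free points and re-concatenates the fixed partial input. A toy loop with a placeholder step function (an assumption; the real step computes the posterior mean from the PVCNN2 denoiser) makes the bookkeeping visible:

```python
import torch

B, D, sv, free, T = 2, 3, 200, 1848, 50
partial_x = torch.randn(B, D, sv)                 # observed points, kept fixed
x = torch.cat([partial_x, torch.randn(B, D, free)], dim=-1)

def dummy_step(data, t):                          # stand-in for one p_sample call
    mean = 0.99 * data[:, :, sv:]                 # placeholder "model mean"
    noise = torch.randn_like(mean) if t > 0 else torch.zeros_like(mean)  # no noise at t=0
    return torch.cat([data[:, :, :sv], mean + 0.01 * noise], dim=-1)

for t in reversed(range(T)):
    x = dummy_step(x, t)

assert torch.equal(x[:, :, :sv], partial_x)       # the conditioning never changes
assert x.shape == (B, D, sv + free)
```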
- - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) - - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 
2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - -def get_dataset(data_root, npoints, category): - - train_ds = GANdatasetPartNet('train', data_root, category, npoints) - return train_ds - - -def get_dataloader(opt, train_dataset, test_dataset=None): - - if opt.distribution_type == 'multi': - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, - num_replicas=opt.world_size, - rank=opt.rank - ) - if test_dataset is not None: - 
test_sampler = torch.utils.data.distributed.DistributedSampler( - test_dataset, - num_replicas=opt.world_size, - rank=opt.rank - ) - else: - test_sampler = None - else: - train_sampler = None - test_sampler = None - - train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler, - shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True) - - if test_dataset is not None: - test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - else: - test_dataloader = None - - return train_dataloader, test_dataloader, train_sampler, test_sampler - - -def train(gpu, opt, output_dir, noises_init): - - set_seed(opt) - logger = setup_logging(output_dir) - if opt.distribution_type == 'multi': - should_diag = gpu==0 - else: - should_diag = True - if should_diag: - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - if opt.distribution_type == 'multi': - if opt.dist_url == "env://" and opt.rank == -1: - opt.rank = int(os.environ["RANK"]) - - base_rank = opt.rank * opt.ngpus_per_node - opt.rank = base_rank + gpu - dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, - world_size=opt.world_size, rank=opt.rank) - - opt.bs = int(opt.bs / opt.ngpus_per_node) - opt.workers = 0 - - opt.saveIter = int(opt.saveIter / opt.ngpus_per_node) - opt.diagIter = int(opt.diagIter / opt.ngpus_per_node) - opt.vizIter = int(opt.vizIter / opt.ngpus_per_node) - - - ''' data ''' - train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes) - dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) - - - ''' - create networks - ''' - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.distribution_type == 'multi': # Multiple processes, single GPU per process - def _transform_(m): - return nn.parallel.DistributedDataParallel( - m, device_ids=[gpu], output_device=gpu) - - torch.cuda.set_device(gpu) - netE.cuda(gpu) - netE.multi_gpu_wrapper(_transform_) - - - elif opt.distribution_type == 'single': - def _transform_(m): - return nn.parallel.DataParallel(m) - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - elif gpu is not None: - torch.cuda.set_device(gpu) - netE = netE.cuda(gpu) - else: - raise ValueError('distribution_type = multi | single | None') - - if should_diag: - logger.info(opt) - - optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999)) - - lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma) - - if opt.netE != '': - ckpt = torch.load(opt.netE) - netE.load_state_dict(ckpt['model_state']) - optimizer.load_state_dict(ckpt['optimizer_state']) - - if opt.netE != '': - start_epoch = torch.load(opt.netE)['epoch'] + 1 - else: - start_epoch = 0 - - - for epoch in range(start_epoch, opt.niter): - - if opt.distribution_type == 'multi': - train_sampler.set_epoch(epoch) - - lr_scheduler.step(epoch) - - for i, data in enumerate(dataloader): - - x = data['real'] - sv_x = data['raw'] - - sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1) - noises_batch = noises_init[data['idx']] - - ''' - train diffusion - ''' - - if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None): - sv_x = sv_x.cuda(gpu) - noises_batch = noises_batch.cuda(gpu) - elif opt.distribution_type == 'single': - sv_x = 
sv_x.cuda() - noises_batch = noises_batch.cuda() - - loss = netE.get_loss_iter(sv_x, noises_batch).mean() - - optimizer.zero_grad() - loss.backward() - netpNorm, netgradNorm = getGradNorm(netE) - if opt.grad_clip is not None: - torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip) - - optimizer.step() - - - if i % opt.print_freq == 0 and should_diag: - - logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, ' - 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' - .format( - epoch, opt.niter, i, len(dataloader),loss.item(), - netpNorm, netgradNorm, - )) - - if (epoch + 1) % opt.diagIter == 0 and should_diag: - - logger.info('Diagnosis:') - - x_range = [x.min().item(), x.max().item()] - kl_stats = netE.all_kl(sv_x) - logger.info(' [{:>3d}/{:>3d}] ' - 'x_range: [{:>10.4f}, {:>10.4f}], ' - 'total_bpd_b: {:>10.4f}, ' - 'terms_bpd: {:>10.4f}, ' - 'prior_bpd_b: {:>10.4f} ' - 'mse_bt: {:>10.4f} ' - .format( - epoch, opt.niter, - *x_range, - kl_stats['total_bpd_b'].item(), - kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item() - )) - - - - if (epoch + 1) % opt.vizIter == 0 and should_diag: - logger.info('Generation: eval') - - netE.eval() - m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0) - - with torch.no_grad(): - - x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu() - - - gen_stats = [x_gen_eval.mean(), x_gen_eval.std()] - gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()] - - logger.info(' [{:>3d}/{:>3d}] ' - 'eval_gen_range: [{:>10.4f}, {:>10.4f}] ' - 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ' - .format( - epoch, opt.niter, - *gen_eval_range, *gen_stats, - )) - - export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch), - (x_gen_eval*s+m).transpose(1, 2).numpy()*3) - - export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch), - (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3) - - export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch), - (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3) - - - netE.train() - - - - - - - - if (epoch + 1) % opt.saveIter == 0: - - if should_diag: - - - save_dict = { - 'epoch': epoch, - 'model_state': netE.state_dict(), - 'optimizer_state': optimizer.state_dict() - } - - torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch)) - - - if opt.distribution_type == 'multi': - dist.barrier() - map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu} - netE.load_state_dict( - torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state']) - - dist.destroy_process_group() - -def main(): - opt = parse_args() - - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - - ''' workaround ''' - - train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes) - noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints) - - if opt.dist_url == "env://" and opt.world_size == -1: - opt.world_size = int(os.environ["WORLD_SIZE"]) - - if opt.distribution_type == 'multi': - opt.ngpus_per_node = torch.cuda.device_count() - opt.world_size = opt.ngpus_per_node * opt.world_size - mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init)) - else: - train(opt.gpu, opt, output_dir, noises_init) - - - -def 
parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet/', help='input batch size') - - parser.add_argument('--classes', default='Table') - - parser.add_argument('--bs', type=int, default=64, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=1024) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002') - parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') - parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM') - parser.add_argument('--grad_clip', type=float, default=None, help='weight decay for EBM') - parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM') - - parser.add_argument('--netE', default='', help="path to netE (to continue training)") - - - '''distributed''' - parser.add_argument('--world_size', default=1, type=int, - help='Number of distributed nodes.') - parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str, - help='url used to set up distributed training') - parser.add_argument('--dist_backend', default='nccl', type=str, - help='distributed backend') - parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None], - help='Use multi-processing distributed training to launch ' - 'N processes per node, which has N GPUs. This is the ' - 'fastest way to use PyTorch for either single node or ' - 'multi node data parallel training') - parser.add_argument('--rank', default=0, type=int, - help='node rank for distributed training') - parser.add_argument('--gpu', default=None, type=int, - help='GPU id to use. 
None means using all available GPUs.') - - '''eval''' - parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch') - parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch') - parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch') - parser.add_argument('--print_freq', default=50, type=int,help='unit: iter') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - - opt = parser.parse_args() - - return opt - -if __name__ == '__main__': - main() diff --git a/shape_completion/__init__.py b/shape_completion/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/shape_completion/control_gen_chair.py b/shape_completion/control_gen_chair.py deleted file mode 100644 index 96a1387..0000000 --- a/shape_completion/control_gen_chair.py +++ /dev/null @@ -1,660 +0,0 @@ - -from pprint import pprint -from metrics.evaluation_metrics import compute_all_metrics, EMD_CD - -import torch.nn as nn -import torch.utils.data - -import argparse -from torch.distributions import Normal -from utils.visualize import * -from utils.file_utils import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_completion import PVCNN2Base - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.shapenet_data_sv import * -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. 
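`control_gen_chair.py` imports its quantitative metrics (`compute_all_metrics`, `EMD_CD`) from `metrics.evaluation_metrics`, which this diff does not show. For orientation only, here is a naive O(N^2) Chamfer distance in plain torch; it is not the repo's implementation:

```python
import torch

def chamfer_distance(a, b):
    """Naive symmetric Chamfer distance between point clouds a, b of shape (B, N, 3)."""
    d = torch.cdist(a, b) ** 2                    # (B, Na, Nb) squared pairwise distances
    return d.min(dim=2).values.mean(dim=1) + d.min(dim=1).values.mean(dim=1)

a, b = torch.randn(4, 1024, 3), torch.randn(4, 1024, 3)
print(chamfer_distance(a, b).shape)               # torch.Size([4])
assert torch.allclose(chamfer_distance(a, a), torch.zeros(4), atol=1e-6)
```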
- betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = 
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate images - Useful for visualizing how denoised images evolve over time - Args: - repeat_noise_steps (int): Number of denoising timesteps in which the same noise - is used across the batch. If >= 0, the initial noise is the same for all batch elemements. - """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - - - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) - - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. 
seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) - - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb, denoise_fn, noise_fn=torch.randn): - - assert t >= 1 - - t_vec = torch.empty(x0_part.shape[0], dtype=torch.int64, device=x0_part.device).fill_(t-1) - encoding0 = self.q_sample(x0_part, t_vec) - encoding1 = self.q_sample(x1_part, t_vec) - - enc = encoding0 * (1-lamb) + (lamb) * encoding1 - - img_t = torch.cat([torch.cat([x0_sv[:,:,:int(self.sv_points*(1-lamb))], x1_sv[:,:,:(self.sv_points - int(self.sv_points*(1-lamb)))]], dim=-1), enc], dim=-1) - - for k in reversed(range(0,t)): - t_ = torch.empty(img_t.shape[0], dtype=torch.int64, device=img_t.device).fill_(k) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False).detach() - - - return img_t - - -class PVCNN2(PVCNN2Base): - sa_blocks 
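The `interpolate` method above is the multimodal trick used later in `generate_multimodal`: both completions are diffused forward to step t with `q_sample`, mixed linearly with weight `lamb`, paired with a proportionally mixed set of conditioning points, and then run back down the reverse chain. A schematic of just the mixing step, with toy tensors and no model:

```python
import torch

sv, free = 200, 1848
x0_part, x1_part = torch.randn(1, 3, free), torch.randn(1, 3, free)  # two completions
x0_sv, x1_sv = torch.randn(1, 3, sv), torch.randn(1, 3, sv)          # their partial inputs
lamb = 0.5

# forward-diffuse both completions to the same t (pure noise here as a stand-in for q_sample)
enc0, enc1 = torch.randn_like(x0_part), torch.randn_like(x1_part)
enc = (1 - lamb) * enc0 + lamb * enc1

# mix the conditioning points in proportion to lamb, as in `interpolate` above
k = int(sv * (1 - lamb))
img_t = torch.cat([torch.cat([x0_sv[:, :, :k], x1_sv[:, :, :sv - k]], dim=-1), enc], dim=-1)
assert img_t.shape == (1, 3, sv + free)
# from here, the real code runs the reverse chain (p_sample) for k = t-1 .. 0
```

In the evaluation script below, this is driven as `netE.interpolate(x0_part, x1_part, x0_sv, x1_sv, 1000, lamb)` with `lamb` swept over `np.linspace(0.1, 0.9, 5)`.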
= [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb): - - return self.diffusion.interpolate(x0_part, x1_part, x0_sv, x1_sv, t, lamb, self._denoise) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise 
NotImplementedError(schedule_type) - return betas - - - -############################################################################# -def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - radius=3, elev=-89, azim=180, img_size=512, focal_length=1000, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - -def get_mvr_dataset(pc_dataroot, views_root, npoints,category, get_image=True): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, get_image=get_image, - ) - return te_dataset - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - if i!=3: - continue - gt_all = data['test_points'] - x_all = data['sv_points'] - img_all = data['image'] - - for v in range(20): - - recons = [] - svs = [] - for p in [0,1]: - x = x_all[:,p].transpose(1, 2).contiguous() - img = img_all[:,p] - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - recons.append(recon) - svs.append(x[:, :opt.svpoints,:]) - - for l, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): - - # im = np.fliplr(np.flipud(d[-1])) - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,l), 'mode_%03d'%v, 'depth_%d'%p), - (torch.stack(d[:-1], dim=0)* s[0] + m[0]).numpy(), cat='chair') - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,l), 'mode_%03d'%v, 'depth_%d'%p), - (torch.stack(d[:-1], dim=0)* s[0] + m[0]).numpy()) - plt.imsave(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v, 'depth_%d.png' % p), - d[-1].permute(1, 2, 0), cmap='gray') - - x0_part = recons[0].transpose(1, 2).contiguous()[:,:,opt.svpoints:].cuda() - x1_part = recons[1].transpose(1, 2).contiguous()[:,:,opt.svpoints:].cuda() - x0_sv = svs[0].transpose(1,2).cuda() - x1_sv = svs[1].transpose(1,2).cuda() - - interres = [] - for lamb in np.linspace(0.1, 0.9, 5): - res = netE.interpolate(x0_part, x1_part, x0_sv, x1_sv, 1000, lamb) - - res = torch.cat([x0_sv, x1_sv, 
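# A standalone illustration of the schedules produced by get_betas() above,
# assuming the defaults used by parse_args below (beta_start=1e-4,
# beta_end=2e-2, time_num=1000):
import numpy as np
betas_linear = np.linspace(1e-4, 2e-2, 1000)     # schedule_type='linear'
betas_warm = 2e-2 * np.ones(1000)                # schedule_type='warm0.1'
betas_warm[:100] = np.linspace(1e-4, 2e-2, 100)  # linear ramp over the first 10%
assert betas_linear.shape == betas_warm.shape == (1000,)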
res[:,:,opt.svpoints:]], dim=-1).detach().cpu().transpose(1,2).contiguous() - interres.append(res) - for l, d in enumerate(torch.stack(interres, dim=1)): - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, l), 'mode_%03d' % v), - (d* s[0] + m[0]).numpy(), cat='chair') - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v), - (d * s[0] + m[0]).numpy()) - -def main(opt): - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - - opt.netE = ckpt - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - generate_multimodal(opt, netE, outf_syn, logger) - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - help='input batch size') - parser.add_argument('--classes', default=['chair']) - - parser.add_argument('--batch_size', type=int, default=1, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--eval_recon_mvr', default=False) - parser.add_argument('--generate_multimodal', default=False) - parser.add_argument('--eval_saved', default=False) - parser.add_argument('--eval_redwood', default=True) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=200) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - - # constrain function - parser.add_argument('--constrain_eps', default=0.2) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - 
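# One caveat in the checkpoint sorting above: str.strip('.pth') removes any
# of the characters '.', 'p', 't', 'h' from both ends, not the literal
# suffix. It happens to be harmless for names like 'epoch_2799.pth'; a safer
# equivalent spelling, sketched here, is:
import os
name = 'epoch_2799.pth'
step = int(os.path.splitext(name)[0].split('_')[-1])
assert step == 2799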
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') - - parser.add_argument('--manualSeed', default=None, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - set_seed(opt) - main(opt) diff --git a/shape_completion/teaser_chair.py b/shape_completion/teaser_chair.py deleted file mode 100644 index 5d022e5..0000000 --- a/shape_completion/teaser_chair.py +++ /dev/null @@ -1,706 +0,0 @@ - - -import torch.nn as nn -import torch.utils.data - -import argparse -from torch.distributions import Normal -from utils.visualize import * -from utils.file_utils import * -from utils.mitsuba_renderer2 import write_to_xml_batch, write_to_xml -from model.pvcnn_completion import PVCNN2Base - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.shapenet_data_sv import * -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. 
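# A quick numerical check of normal_kl() above, assuming nothing beyond
# torch.distributions: the log-variance form agrees with the closed-form KL
# between the corresponding Normal distributions.
import torch
from torch.distributions import Normal, kl_divergence
m1, lv1, m2, lv2 = torch.randn(4), torch.randn(4), torch.randn(4), torch.randn(4)
ours = 0.5 * (-1.0 + lv2 - lv1 + torch.exp(lv1 - lv2) + (m1 - m2) ** 2 * torch.exp(-lv2))
ref = kl_divergence(Normal(m1, torch.exp(0.5 * lv1)), Normal(m2, torch.exp(0.5 * lv2)))
assert torch.allclose(ours, ref, atol=1e-5)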
- alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), 
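# A numeric spot-check of the identity stated in the comment above
# ("equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)"),
# assuming a linear schedule; t = 0 is skipped to avoid a division by zero
# in the right-hand form.
import numpy as np
betas = np.linspace(1e-4, 2e-2, 1000)
alphas = 1.0 - betas
abar = np.cumprod(alphas)
abar_prev = np.append(1.0, abar[:-1])
lhs = (betas * (1.0 - abar_prev) / (1.0 - abar))[1:]
rhs = 1.0 / (1.0 / (1.0 - abar_prev[1:]) + alphas[1:] / betas[1:])
assert np.allclose(lhs, rhs)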
self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.model_mean_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - - def p_sample_loop_trajectory2(self, partial_x, denoise_fn, shape, device, num_save, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate point clouds - Useful for visualizing how denoised shapes evolve over time - Args: - num_save (int): Number of intermediate point clouds to keep along - the reverse trajectory.
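# The algebra behind _predict_xstart_from_eps() above, as a self-contained
# check with toy tensors (shapes are illustrative): it exactly inverts the
# q_sample parameterization x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps.
import torch
abar_t = torch.tensor(0.3)
x0, eps = torch.randn(2, 3, 8), torch.randn(2, 3, 8)
x_t = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps
x0_rec = (1 / abar_t).sqrt() * x_t - (1 / abar_t - 1).sqrt() * eps
assert torch.allclose(x0_rec, x0, atol=1e-5)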
- """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - scale = np.exp(np.log(1/total_steps)/num_save) - save_step = total_steps - - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - imgs = [img_t.detach().cpu()] - for t in reversed(range(0,total_steps)): - if (t+1) == save_step and t > 0 and len(imgs) 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = 
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate images - Useful for visualizing how denoised images evolve over time - Args: - repeat_noise_steps (int): Number of denoising timesteps in which the same noise - is used across the batch. If >= 0, the initial noise is the same for all batch elemements. - """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) - - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. 
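# The conditioning convention used by p_sample_loop() and p_losses() above,
# as a toy shape check (sizes mirror the npoints/svpoints defaults below):
# tensors are (B, 3, N) with the first sv_points entries along the last axis
# holding the observed partial cloud, which is re-attached unchanged after
# every reverse step.
import torch
B, C, N, sv = 2, 3, 2048, 200
partial = torch.randn(B, C, sv)
x_t = torch.cat([partial, torch.randn(B, C, N - sv)], dim=-1)
assert x_t.shape == (B, C, N) and torch.equal(x_t[:, :, :sv], partial)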
seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) - - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - - -############################################################################# -def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - radius=3, elev=0, azim=0, img_size=512, focal_length=1000, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - -def get_mvr_dataset(pc_dataroot, views_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - 
tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - - -def evaluate_recon_mvr(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - img_all = data['image'] - - - B,V,N,C = x_all.shape - gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) - - x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - img = img_all.reshape(B * V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - images.append(img) - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean().item() for key, val in results.items()}) - logger.info({key: val.mean().item() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - img_all = data['image'] - - Path(os.path.join(save_dir, 'x_%03d'%k)).mkdir(exist_ok=True, parents=True) - Path(os.path.join(save_dir, 'x_ply_%03d' % k)).mkdir(exist_ok=True, 
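# The metric bookkeeping in evaluate_recon_mvr() above, sketched with a
# stand-in tensor: EMD_CD(..., reduced=False) yields one value per
# (shape, view) pair, so the flat length-B*V tensor is reshaped to (B, V)
# before averaging.
import torch
B, V = 4, 5
flat = torch.rand(B * V)      # stand-in for a per-sample metric
per_view = flat.reshape(B, V)
assert torch.isclose(per_view.mean(), flat.mean())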
parents=True) - - for v in range(5): - x = x_all.transpose(1, 2).contiguous() - img = img_all - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): - - im = np.fliplr(np.flipud(d[-1])) - plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - -def main(opt): - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - - opt.netE = ckpt - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - if opt.eval_recon_mvr: - # Evaluate generation - evaluate_recon_mvr(opt, netE, outf_syn, logger) - - if opt.generate_multimodal: - - generate_multimodal(opt, netE, outf_syn, logger) - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - help='input batch size') - parser.add_argument('--classes', default=['car']) - - parser.add_argument('--batch_size', type=int, default=8, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--eval_recon_mvr', default=True) - parser.add_argument('--generate_multimodal', default=False) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=200) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - - # constrain function - 
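# One caveat for the boolean defaults above: flags declared as
# add_argument('--attention', default=True) carry no type= or action=, so any
# value given on the command line arrives as a non-empty string and is
# therefore truthy. A conventional alternative, sketched here:
import argparse
p = argparse.ArgumentParser()
p.add_argument('--attention', action='store_true')  # False unless the flag is passed
assert p.parse_args(['--attention']).attention is True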
parser.add_argument('--constrain_eps', default=0.2) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/3_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-03-08-40', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - set_seed(opt) - - main(opt) diff --git a/shape_completion/test_chair.py b/shape_completion/test_chair.py deleted file mode 100644 index dda8e3c..0000000 --- a/shape_completion/test_chair.py +++ /dev/null @@ -1,753 +0,0 @@ - -from pprint import pprint -from metrics.evaluation_metrics import compute_all_metrics, EMD_CD - -import torch.nn as nn -import torch.utils.data - -import argparse -from torch.distributions import Normal -from utils.visualize import * -from utils.file_utils import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_completion import PVCNN2Base - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.shapenet_data_sv import * -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. 
- betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = 
noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate images - Useful for visualizing how denoised images evolve over time - Args: - repeat_noise_steps (int): Number of denoising timesteps in which the same noise - is used across the batch. If >= 0, the initial noise is the same for all batch elemements. - """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) - - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. 
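# Units note for _vb_terms_bpd() above, with a one-line conversion check:
# the KL is computed in nats and divided by log(2), so every bpd quantity in
# this file is reported in bits per dimension.
import numpy as np
assert abs(np.log(4.0) / np.log(2.0) - 2.0) < 1e-12  # ln(4) nats == 2 bits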
seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) - - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, 
args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - - -############################################################################# -def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - radius=3, elev=-89, azim=180, img_size=512, focal_length=1000, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - -def get_mvr_dataset(pc_dataroot, views_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, 
split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - - -def evaluate_recon_mvr(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - # img_all = data['image'] - - - B,V,N,C = x_all.shape - gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) - - x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - # img = img_all.reshape(B * V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - # img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - # images.append(img) - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - # images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean().item() for key, val in results.items()}) - logger.info({key: val.mean().item() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - del ref_pcs, masked, results - -def evaluate_saved(opt, netE, save_dir, logger): - ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn' - - gt_pth = ours_base + '/recon_gt.pth' - ours_pth = ours_base + '/ours_results.pth' - gt = torch.load(gt_pth).permute(1,0,2,3) - ours = torch.load(ours_pth)['pc'].permute(1,0,2,3) - - all_res = {} - for i, (gt_, ours_) in enumerate(zip(gt, ours)): - results = compute_all_metrics(gt_, ours_, opt.batch_size) - - for key, val in results.items(): - if i == 0: - all_res[key] = val - else: - all_res[key] += val - pprint(results) - for key, val in all_res.items(): - all_res[key] = val / gt.shape[0] - - pprint({key: 
val.mean().item() for key, val in all_res.items()}) - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - img_all = data['image'] - - for v in range(6): - x = x_all.transpose(1, 2).contiguous() - img = img_all - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): - - im = np.fliplr(np.flipud(d[-1])) - plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair') - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - -def redwood_demo(opt, netE, save_dir, logger): - import open3d as o3d - pth = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc_partial.ply" - pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc.ply" - - points = np.asarray(o3d.io.read_point_cloud(pth).points) - - gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points) - - np.save('gt.npy', gt_points) - - test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - - m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float() - - x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float() - x = (x-m)/s - - x = x.transpose(1,2).cuda() - - res = [] - for k in range(20): - recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - recon = recon.transpose(1, 2).contiguous() - recon = recon * s+ m - res.append(recon) - res = torch.cat(res, dim=0) - - write_to_xml_batch(os.path.join(save_dir, 'xml'), - (res).numpy(), cat='chair') - - export_to_pc_batch(os.path.join(save_dir, 'ply'), - (res).numpy()) - - torch.save(res, os.path.join(save_dir, 'redwood_demo.pth')) - - pcwrite(os.path.join(save_dir, 'ply', 'gt.ply'), - gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)]) - write_to_xml_batch(os.path.join(save_dir, 'xml_gt'), - gt_points[None], cat='chair') - - exit() - -def main(opt): - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - 
netE.multi_gpu_wrapper(_transform_) - - netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - - opt.netE = ckpt - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - if opt.eval_recon_mvr: - # Evaluate generation - evaluate_recon_mvr(opt, netE, outf_syn, logger) - - if opt.generate_multimodal: - - generate_multimodal(opt, netE, outf_syn, logger) - - if opt.eval_saved: - evaluate_saved(opt, netE, outf_syn, logger) - - if opt.eval_redwood: - redwood_demo(opt, netE, outf_syn, logger) - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - help='input batch size') - parser.add_argument('--classes', default=['chair']) - - parser.add_argument('--batch_size', type=int, default=1, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--eval_recon_mvr', default=False) - parser.add_argument('--generate_multimodal', default=False) - parser.add_argument('--eval_saved', default=False) - parser.add_argument('--eval_redwood', default=True) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=200) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - - # constrain function - parser.add_argument('--constrain_eps', default=0.2) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - - main(opt) diff --git a/shape_completion/test_partnet_chair.py b/shape_completion/test_partnet_chair.py deleted file mode 100644 index babb8a0..0000000 --- a/shape_completion/test_partnet_chair.py +++ /dev/null @@ -1,599 +0,0 @@ - -from pprint import pprint -from tqdm import tqdm -import torch.nn as nn 
-import torch.utils.data - -import argparse -from torch.distributions import Normal -from utils.visualize import * -from utils.file_utils import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_completion import PVCNN2Base - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.partnet import GANdatasetPartNet -import trimesh -import csv -import numpy as np -import random -from plyfile import PlyData, PlyElement - - -def write_ply(points, filename, text=False): - """ input: Nx3, write points to filename as PLY format. """ - points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] - vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) - el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) - with open(filename, mode='wb') as f: - PlyData([el], text=text).write(f) - -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. 
/ alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * 
torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.model_mean_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate point clouds - Useful for visualizing how denoised samples evolve over time - Args: - freq (int): Interval, in timesteps, at which intermediate results are appended - to the returned list; the final step is always included.
- """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) - - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) 
- - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,),
device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - - -############################################################################# -def get_dataset(data_root, npoints, category): - - train_ds = GANdatasetPartNet('test', data_root, category, npoints) - return train_ds - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['real'] - x_all = data['raw'] - - for j in range(5): - x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1) - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - - - - for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))): - - partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1) - rec = d[1] - rid = d[2] - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j), - (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair') - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j), - (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy()) - - raw_id = rid.split('.')[0] - save_sample_dir = os.path.join(save_dir, "{}".format(raw_id)) - Path(save_sample_dir).mkdir(parents=True, exist_ok=True) - # save input partial shape - if j == 0: - save_path = os.path.join(save_sample_dir, "raw.ply") - write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path) - # save completed shape - save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j)) - write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path) - - -def main(opt): 
- exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - - opt.netE = ckpt - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - - if opt.generate_multimodal: - - generate_multimodal(opt, netE, outf_syn, logger) - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet', - help='input batch size') - - parser.add_argument('--classes', default='Chair') - - parser.add_argument('--batch_size', type=int, default=64, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--eval_recon_mvr', default=False) - parser.add_argument('--generate_multimodal', default=True) - parser.add_argument('--eval_saved', default=False) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=1024) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - - # constrain function - parser.add_argument('--constrain_eps', default=0.2) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - - main(opt) diff --git a/shape_completion/test_partnet_table.py b/shape_completion/test_partnet_table.py deleted file mode 100644 index 8efbde5..0000000 --- a/shape_completion/test_partnet_table.py +++ /dev/null @@ -1,599 +0,0 @@ - -from pprint import pprint -from tqdm 
import tqdm -import torch.nn as nn -import torch.utils.data - -import argparse -from torch.distributions import Normal -from utils.visualize import * -from utils.file_utils import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_completion import PVCNN2Base - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.partnet import GANdatasetPartNet -import trimesh -import csv -import numpy as np -import random -from plyfile import PlyData, PlyElement - - -def write_ply(points, filename, text=False): - """ input: Nx3, write points to filename as PLY format. """ - points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] - vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) - el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) - with open(filename, mode='wb') as f: - PlyData([el], text=text).write(f) - -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. 
/ alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * 
torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.model_mean_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate point clouds - Useful for visualizing how denoised samples evolve over time - Args: - freq (int): Interval, in timesteps, at which intermediate results are appended - to the returned list; the final step is always included.
- """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) - - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) 
- - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,),
device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - - -############################################################################# -def get_dataset(data_root, npoints, category): - - train_ds = GANdatasetPartNet('test', data_root, category, npoints) - return train_ds - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['real'] - x_all = data['raw'] - - for j in range(5): - x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1) - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - - - - for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))): - - partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1) - rec = d[1] - rid = d[2] - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j), - (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair') - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j), - (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy()) - - raw_id = rid.split('.')[0] - save_sample_dir = os.path.join(save_dir, "{}".format(raw_id)) - Path(save_sample_dir).mkdir(parents=True, exist_ok=True) - # save input partial shape - if j == 0: - save_path = os.path.join(save_sample_dir, "raw.ply") - write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path) - # save completed shape - save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j)) - write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path) - - -def main(opt): 
- exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - - opt.netE = ckpt - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - - if opt.generate_multimodal: - - generate_multimodal(opt, netE, outf_syn, logger) - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet', - help='input batch size') - - parser.add_argument('--classes', default='Table') - - parser.add_argument('--batch_size', type=int, default=64, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--eval_recon_mvr', default=False) - parser.add_argument('--generate_multimodal', default=True) - parser.add_argument('--eval_saved', default=False) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=1024) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - - # constrain function - parser.add_argument('--constrain_eps', default=0.2) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - - main(opt) diff --git a/shape_completion/test_plane.py b/shape_completion/test_plane.py deleted file mode 100644 index 5981c3e..0000000 --- a/shape_completion/test_plane.py +++ /dev/null @@ -1,681 +0,0 @@ - -from pprint import pprint -from metrics.evaluation_metrics 
import compute_all_metrics, EMD_CD - -import torch.nn as nn -import torch.utils.data - -import argparse -from torch.distributions import Normal - -from utils.file_utils import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_completion import PVCNN2Base - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.shapenet_data_sv import * -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = 
self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.model_mean_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate point clouds - Useful for visualizing how denoised point clouds evolve over time - Args: - freq (int): Interval, in timesteps, at which intermediate point clouds - are appended to the returned trajectory.
- """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) - - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) 
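> Editor's note: throughout these completion scripts, conditioning is implemented by freezing the first `sv_points` columns of the point cloud: `p_sample` re-concatenates `data[:, :, :self.sv_points]` onto the denoised remainder at every reverse step, so the observed partial shape is never perturbed. Below is a minimal, self-contained sketch of that conditioning pattern; `toy_denoise` is a hypothetical stand-in for the PVCNN2 network and the update rule is a placeholder, not the model's actual posterior step.

```python
import torch

B, C, N, sv_points = 2, 3, 2048, 200
partial_x = torch.randn(B, C, sv_points)       # observed partial shape (held fixed)
shape = (B, C, N - sv_points)                  # only this remainder is generated

def toy_denoise(x, t):                         # hypothetical stand-in network
    return torch.zeros_like(x)

img_t = torch.cat([partial_x, torch.randn(shape)], dim=-1)
for t in reversed(range(10)):                  # 10 steps instead of 1000 for brevity
    eps = toy_denoise(img_t, t)[:, :, sv_points:]   # network sees the full cloud
    sample = img_t[:, :, sv_points:] - eps          # placeholder reverse-step update
    img_t = torch.cat([partial_x, sample], dim=-1)  # partial points re-attached

assert torch.equal(img_t[:, :, :sv_points], partial_x)  # partial shape untouched
```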
- - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), 
device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - - -############################################################################# -def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - radius=3, elev=0, azim=0, img_size=512, focal_length=1000, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - -def get_mvr_dataset(pc_dataroot, views_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - - -def evaluate_recon_mvr(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = 
data['test_points'] - x_all = data['sv_points'] - img_all = data['image'] - - - B,V,N,C = x_all.shape - gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) - - x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - img = img_all.reshape(B * V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - images.append(img) - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean().item() for key, val in results.items()}) - logger.info({key: val.mean().item() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - img_all = data['image'] - - Path(os.path.join(save_dir, 'x_%03d' % i)).mkdir(exist_ok=True, parents=True) - Path(os.path.join(save_dir, 'x_ply_%03d' % i)).mkdir(exist_ok=True, parents=True) - - for v in range(5): - x = x_all.transpose(1, 2).contiguous() - img = img_all - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): - - im = np.fliplr(np.flipud(d[-1])) - plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - -def main(opt): - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = 
setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - - opt.netE = ckpt - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - if opt.eval_recon_mvr: - # Evaluate generation - evaluate_recon_mvr(opt, netE, outf_syn, logger) - - if opt.generate_multimodal: - - generate_multimodal(opt, netE, outf_syn, logger) - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - help='input batch size') - parser.add_argument('--classes', default=['airplane']) - - parser.add_argument('--batch_size', type=int, default=8, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--eval_recon_mvr', default=True) - parser.add_argument('--generate_multimodal', default=False) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=200) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - - # constrain function - parser.add_argument('--constrain_eps', default=0.2) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/airplane_ckpt/', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - set_seed(opt) - - main(opt) diff --git a/shape_completion/test_table.py b/shape_completion/test_table.py deleted file mode 100644 index 8fb92be..0000000 --- a/shape_completion/test_table.py +++ /dev/null @@ -1,764 +0,0 @@ - -from pprint import 
pprint -from metrics.evaluation_metrics import compute_all_metrics, EMD_CD - -import torch.nn as nn -import torch.utils.data - -import argparse -from torch.distributions import Normal -from utils.visualize import * -from utils.file_utils import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_completion import PVCNN2Base - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.shapenet_data_sv import * -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - - -class GaussianDiffusion: - def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.sv_points = sv_points - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = 
self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) - - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) - else: - raise NotImplementedError(self.model_mean_type) - - - assert model_mean.shape == x_recon.shape - assert model_variance.shape == model_log_variance.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) - - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) - - sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - - assert isinstance(shape, (tuple, list)) - img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) - for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - - assert img_t[:,:,self.sv_points:].shape == shape - return img_t - - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): - """ - Generate samples, returning intermediate point clouds - Useful for visualizing how denoised point clouds evolve over time - Args: - freq (int): Interval, in timesteps, at which intermediate point clouds - are appended to the returned trajectory.
- """ - assert isinstance(shape, (tuple, list)) - - total_steps = self.num_timesteps if not keep_running else len(self.betas) - - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - imgs = [img_t] - for t in reversed(range(0,total_steps)): - - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: - imgs.append(img_t) - - assert imgs[-1].shape == shape - return imgs - - '''losses''' - - def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) - - kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) - - return (kl, pred_xstart) if return_pred_xstart else kl - - def p_losses(self, denoise_fn, data_start, t, noise=None): - """ - Training loss calculation - """ - B, D, N = data_start.shape - assert t.shape == torch.Size([B]) - - if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) - - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) - - if self.loss_type == 'mse': - # predict the noise instead of x_start. seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] - - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': - losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) - else: - raise NotImplementedError(self.loss_type) - - assert losses.shape == torch.Size([B]) - return losses - - '''debug''' - - def _prior_bpd(self, x_start): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) - assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) 
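> Editor's note: the `_extract` helper used by every method in this class is a batched table lookup: it gathers one scalar schedule coefficient per batch element at that element's timestep, then reshapes the result to `[B, 1, 1, ...]` so it broadcasts against a `[B, C, N]` point cloud. A self-contained sketch of the same pattern (the coefficient table here is arbitrary, just for illustration):

```python
import torch

T, B = 1000, 4
coef_table = torch.linspace(1.0, 0.01, T)     # any per-timestep schedule table
t = torch.randint(0, T, (B,))                 # one timestep per batch element
x = torch.randn(B, 3, 2048)                   # a [B, C, N] point cloud batch

coef = coef_table.gather(0, t)                # shape [B]: one coef per element
coef = coef.reshape(B, 1, 1)                  # -> [B, 1, 1] for broadcasting
assert (coef * x).shape == x.shape            # scales each cloud by its own coef
```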
- - def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - - with torch.no_grad(): - B, T = x_start.shape[0], self.num_timesteps - - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) - for t in reversed(range(T)): - - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) - # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) - new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) - # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) - # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() - vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt - mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt - assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) - total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) - return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), 
device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - - - -############################################################################# -def get_pc_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - return tr_dataset - -def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - radius=3, elev=-89, azim=180, img_size=512, focal_length=1000, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - -def get_mvr_dataset(pc_dataroot, views_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - - -def evaluate_recon_mvr(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - 
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - # img_all = data['image'] - - - B,V,N,C = x_all.shape - gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) - - x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - # img = img_all.reshape(B * V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - # img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - # images.append(img) - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - # images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean().item() for key, val in results.items()}) - logger.info({key: val.mean().item() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - del ref_pcs, masked, results - -def evaluate_saved(opt, netE, save_dir, logger): - ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn' - - gt_pth = ours_base + '/recon_gt.pth' - ours_pth = ours_base + '/ours_results.pth' - gt = torch.load(gt_pth).permute(1,0,2,3) - ours = torch.load(ours_pth)['pc'].permute(1,0,2,3) - - all_res = {} - for i, (gt_, ours_) in enumerate(zip(gt, ours)): - results = compute_all_metrics(gt_, ours_, opt.batch_size) - - for key, val in results.items(): - if i == 0: - all_res[key] = val - else: - all_res[key] += val - pprint(results) - for key, val in all_res.items(): - all_res[key] = val / gt.shape[0] - - pprint({key: val.mean().item() for key, val in all_res.items()}) - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - img_all = data['image'] - - for v in range(6): - x = x_all.transpose(1, 2).contiguous() - img = img_all - - m, s = 
data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): - - im = np.fliplr(np.flipud(d[-1])) - plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair') - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - -def redwood_demo(opt, netE, save_dir, logger): - import open3d as o3d - pth = "/viscam/u/alexzhou907/01DATA/redwood/01605_sample_1.ply" - pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/01605_pc_gt.ply" - - points = np.asarray(o3d.io.read_point_cloud(pth).points) - - gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points) - - np.save('gt.npy', gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)]) - - write_to_xml_batch(os.path.join(save_dir, 'xml_gt'), - gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)][None], cat='table') - - test_dataset = get_pc_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.classes) - - m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float() - - x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float() - - x = (x-m)/s - - - x = x[None].transpose(1,2).cuda() - - res = [] - for k in range(20): - recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() - recon = recon.transpose(1, 2).contiguous() - recon = recon * s+ m - res.append(recon) - res = torch.cat(res, dim=0) - - write_to_xml_batch(os.path.join(save_dir, 'xml'), - (res).numpy(), cat='table') - - export_to_pc_batch(os.path.join(save_dir, 'ply'), - (res).numpy()) - - torch.save(res, os.path.join(save_dir, 'redwood_demo.pth')) - - exit() - -def main(opt): - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - - opt.netE = ckpt - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - if opt.eval_recon_mvr: - # Evaluate generation - evaluate_recon_mvr(opt, netE, outf_syn, logger) - - if opt.generate_multimodal: - - generate_multimodal(opt, netE, outf_syn, logger) - - if opt.eval_saved: - evaluate_saved(opt, netE, outf_syn, logger) - - if opt.eval_redwood: - redwood_demo(opt, netE, outf_syn, logger) 
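> Editor's note: `main()` above orders checkpoints with `int(x.strip('.pth').split('_')[-1])`. `str.strip('.pth')` removes any of the characters `.`, `p`, `t`, `h` from both ends of the string, not the literal suffix; it happens to work for names like `epoch_2799.pth` but silently misbehaves on other naming schemes. A sketch of a more explicit epoch extraction follows; the `epoch_<N>.pth` naming convention is an assumption, not something the repo guarantees.

```python
import os
import re

def ckpt_epoch(path):
    # Assumes checkpoints are named like 'epoch_2799.pth' (hypothetical naming).
    m = re.search(r'_(\d+)\.pth$', os.path.basename(path))
    if m is None:
        raise ValueError('unrecognized checkpoint name: %s' % path)
    return int(m.group(1))

# Drop-in replacement for the sort in main():
# ckpts = sorted(ckpts, key=ckpt_epoch, reverse=True)
```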
- -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - help='input batch size') - parser.add_argument('--classes', default=['table']) - - parser.add_argument('--batch_size', type=int, default=1, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--eval_recon_mvr', default=False) - parser.add_argument('--generate_multimodal', default=False) - parser.add_argument('--eval_saved', default=False) - parser.add_argument('--eval_redwood', default=True) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=200) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - - # constrain function - parser.add_argument('--constrain_eps', default=0.2) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/9_res32_pc_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-12-16-14-09-50', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - - main(opt) diff --git a/shapenet/__init__.py b/shapenet/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/shapenet/test_car.py b/shapenet/test_car.py deleted file mode 100644 index 7997128..0000000 --- a/shapenet/test_car.py +++ /dev/null @@ -1,905 +0,0 @@ -import torch -from pprint import pprint -from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD -from metrics.evaluation_metrics import compute_all_metrics, EMD_CD - -import torch.nn as nn -import torch.optim as optim -import torch.utils.data - -import argparse -from model.unet import get_model -from torch.distributions import Normal - -from utils.file_utils import * -from utils.visualize import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_generation import PVCNN2Base - -from tqdm import tqdm - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.shapenet_data_sv import * -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions 
parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - -class GaussianDiffusion: - def __init__(self,betas, loss_type, model_mean_type, model_var_type): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
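-
-        For example, with a of shape [T], t of shape [B], and x_shape == (B, C, N),
-        the gathered coefficients are reshaped to [B, 1, 1] so they broadcast over x.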
- """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t) - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) - - if clip_denoised: - x_recon = torch.clamp(x_recon, -.5, .5) - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape == data.shape - assert model_variance.shape == model_log_variance.shape == data.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), 
t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) - assert noise.shape == data.shape - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1)) - - sample = model_mean - if use_var: - sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - assert sample.shape == pred_xstart.shape - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, denoise_fn, shape, device, - noise_fn=torch.randn, constrain_fn=lambda x, t:x, - clip_denoised=True, max_timestep=None, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - if max_timestep is None: - final_time = self.num_timesteps - else: - final_time = max_timestep - - assert isinstance(shape, (tuple, list)) - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - for t in reversed(range(0, final_time if not keep_running else len(self.betas))): - img_t = constrain_fn(img_t, t) - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False).detach() - - - assert img_t.shape == shape - return img_t - - def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): - - assert t >= 1 - - t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) - encoding = self.q_sample(x0, t_vec) - - img_t = encoding - - for k in reversed(range(0,t)): - img_t = constrain_fn(img_t, k) - t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False, use_var=True).detach() - - - return img_t - - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) - - self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, 
prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - assert out.shape == torch.Size([B, D, N]) - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x, - clip_denoised=False, max_timestep=None, - keep_running=False): - return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, - constrain_fn=constrain_fn, - clip_denoised=clip_denoised, max_timestep=max_timestep, - keep_running=keep_running) - - def reconstruct(self, x0, t, constrain_fn=lambda x, t:x): - - return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn) - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - -def get_constrain_function(ground_truth, mask, eps, num_steps=1): - ''' - - :param target_shape_constraint: target voxels - :return: constrained x - ''' - # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2)) - eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 )) - def constrain_fn(x, t): - eps_ = eps_all[t] if (t<1000) else 0 - for _ in range(num_steps): - x = x - eps_ * ((x - ground_truth) * mask) - - - return x - return constrain_fn - - -############################################################################# - -def get_dataset(dataroot, npoints,category,use_mask=False): - tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True, use_mask = use_mask) - te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, - categories=category, split='val', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - all_points_mean=tr_dataset.all_points_mean, - 
all_points_std=tr_dataset.all_points_std, - use_mask=use_mask - ) - return tr_dataset, te_dataset - -def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - radius=3, elev=0, azim=0, img_size=512, focal_length=1000, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - - -def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - - - return te_dataset -def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - - -def evaluate_gen(opt, ref_pcs, logger): - - if ref_pcs is None: - _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'): - x = data['test_points'] - m, s = data['mean'].float(), data['std'].float() - - ref.append(x*s + m) - - ref_pcs = torch.cat(ref, dim=0).contiguous() - - logger.info("Loading sample path: %s" - % (opt.eval_path)) - sample_pcs = torch.load(opt.eval_path).contiguous() - - logger.info("Generation sample size:%s reference size: %s" - % (sample_pcs.size(), ref_pcs.size())) - - - # Compute metrics - results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - results = {k: (v.cpu().detach().item() - if not isinstance(v, float) else v) for k, v in results.items()} - - pprint(results) - logger.info(results) - - jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy()) - pprint('JSD: {}'.format(jsd)) - logger.info('JSD: {}'.format(jsd)) - -def evaluate_recon(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, 
batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - randind = i%24 - gt_all = data['test_points'][:,randind:randind+1] - x_all = data['sv_points'][:,randind:randind+1] - mask_all= data['masks'][:,randind:randind+1] - img_all = data['image'][:,randind:randind+1] - - - B,V,N,C = x_all.shape - - x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous() - mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous() - img = img_all.reshape(B*V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) - # for t in [10]: - # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - # opt.constrain_steps)).detach().cpu() - - recon = netE.gen_samples(x.shape, 'cuda', - constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - opt.constrain_steps), - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - - # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): - # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) - # - # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # - # k+=1 - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - images.append(img) - - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean() for key, val in results.items()}) - logger.info({key: val.mean() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - # - # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - # - # results = {k: (v.cpu().detach().item() - # if not isinstance(v, float) else v) for k, v in results.items()} - # pprint(results) - # logger.info(results) -def evaluate_recon_mvr(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = 
torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - mask_all= data['masks'] - img_all = data['image'] - - - B,V,N,C = x_all.shape - gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) - - # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) - # for t in [10]: - # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - # opt.constrain_steps)).detach().cpu() - - cd_res = [] - recon_res = [] - for p in range(5): - x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - img = img_all.reshape(B * V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x.shape, 'cuda', - constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - opt.constrain_steps), - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - cd = (((recon - x)**2)*mask.transpose(1, 2).contiguous()).sum(dim=(1,2)) - - cd_res.append(cd) - recon_res.append(recon) - - cd_res = torch.stack(cd_res, dim=0) - recon_res = torch.stack(recon_res, dim=0) - _, argmin = torch.min(cd_res, 0) - recon = recon_res[argmin,torch.arange(0,argmin.shape[0])] - - # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): - # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) - # - # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # - # k+=1 - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - images.append(img) - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean().item() for key, val in results.items()}) - logger.info({key: val.mean().item() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - # - # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - # - # results = {k: (v.cpu().detach().item() - # if not isinstance(v, float) else v) for k, v in results.items()} - # 
pprint(results) - # logger.info(results) - - - -def generate(netE, opt, logger): - - _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes) - - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - with torch.no_grad(): - - samples = [] - ref = [] - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'): - - x = data['test_points'].transpose(1,2) - m, s = data['mean'].float(), data['std'].float() - - gen = netE.gen_samples(x.shape, - 'cuda', clip_denoised=False).detach().cpu() - - gen = gen.transpose(1,2).contiguous() - x = x.transpose(1,2).contiguous() - - - - gen = gen * s + m - x = x * s + m - samples.append(gen) - ref.append(x) - - visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None, - None, None) - - write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), default_color='b') - - samples = torch.cat(samples, dim=0) - ref = torch.cat(ref, dim=0) - - torch.save(samples, opt.eval_path) - - - - return ref - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_svr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - mask_all= data['masks'] - img_all = data['image'] - - - # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) - # for t in [10]: - # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - # opt.constrain_steps)).detach().cpu() - - Path(os.path.join(save_dir, 'x_%03d'%k)).mkdir(exist_ok=True, parents=True) - Path(os.path.join(save_dir, 'x_ply_%03d' % k)).mkdir(exist_ok=True, parents=True) - - for v in range(5): - x = x_all.transpose(1, 2).contiguous() - mask = mask_all.transpose(1, 2).contiguous() - img = img_all - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x.shape, 'cuda', - constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - opt.constrain_steps), - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - - for d in zip(list(gt_all), list(recon), list(x), list(img)): - - im = np.fliplr(np.flipud(d[-1])) - plt.imsave(os.path.join(save_dir, 'depth_%03d.png'%k), im, cmap='gray') - write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k, 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k, 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - k+=1 - -def main(opt): - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, 
opt.beta_start, opt.beta_end, opt.time_num)
-    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
-    if opt.cuda:
-        netE.cuda()
-
-    def _transform_(m):
-        return nn.parallel.DataParallel(m)
-
-    netE = netE.cuda()
-    netE.multi_gpu_wrapper(_transform_)
-
-    netE.eval()
-
-    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
-    with torch.no_grad():
-        for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
-            #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-02-45/epoch_1399.pth'
-            #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/67_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-03-38/epoch_2799.pth'
-            opt.netE = ckpt
-            logger.info("Resume Path:%s" % opt.netE)
-
-            resumed_param = torch.load(opt.netE)
-            netE.load_state_dict(resumed_param['model_state'])
-
-
-            ref = None
-            if opt.generate:
-                epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1])
-                opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch))
-                Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
-                ref = generate(netE, opt, logger)
-            if opt.eval_gen:
-                # Evaluate generation
-                evaluate_gen(opt, ref, logger)
-
-            if opt.eval_recon:
-                # Evaluate single-view reconstruction
-                evaluate_recon(opt, netE, outf_syn, logger)
-
-            if opt.eval_recon_mvr:
-                # Evaluate multi-view reconstruction
-                evaluate_recon_mvr(opt, netE, outf_syn, logger)
-
-            if opt.generate_multimodal:
-
-                generate_multimodal(opt, netE, outf_syn, logger)
-
-    exit()
-
-def parse_args():
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
-                        help='path to the ShapeNet point cloud data')
-    parser.add_argument('--classes', default=['car'])
-
-    parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
-    parser.add_argument('--workers', type=int, default=16, help='workers')
-    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
-    parser.add_argument('--generate', default=True)
-    parser.add_argument('--eval_gen', default=True)
-
-    parser.add_argument('--nc', default=3)
-    parser.add_argument('--npoints', default=2048)
-    '''model'''
-    parser.add_argument('--beta_start', default=0.0001)
-    parser.add_argument('--beta_end', default=0.02)
-    parser.add_argument('--schedule_type', default='linear')
-    parser.add_argument('--time_num', default=1000)
-
-    #params
-    parser.add_argument('--attention', default=True)
-    parser.add_argument('--dropout', default=0.1)
-    parser.add_argument('--embed_dim', type=int, default=64)
-    parser.add_argument('--loss_type', default='mse')
-    parser.add_argument('--model_mean_type', default='eps')
-    parser.add_argument('--model_var_type', default='fixedsmall')
-
-
-    # constrain function
-    parser.add_argument('--constrain_eps', default=0.2)
-    parser.add_argument('--constrain_steps', type=int, default=1)
-
-    parser.add_argument('--ckpt_dir', default='', required=True, help='directory of .pth checkpoints to evaluate')
-
-    '''eval'''
-
-    parser.add_argument('--eval_path',
-                        default='')
-
-    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
-    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
-    opt = 
parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - set_seed(opt) - - main(opt) - - # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair \ No newline at end of file diff --git a/shapenet/test_chair.py b/shapenet/test_chair.py deleted file mode 100644 index c88ec0b..0000000 --- a/shapenet/test_chair.py +++ /dev/null @@ -1,911 +0,0 @@ -import torch -from pprint import pprint -from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD -from metrics.evaluation_metrics import compute_all_metrics, EMD_CD - -import torch.nn as nn -import torch.optim as optim -import torch.utils.data - -import argparse -from model.unet import get_model -from torch.distributions import Normal - -from utils.file_utils import * -from utils.visualize import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_generation import PVCNN2Base - -from tqdm import tqdm - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.shapenet_data_sv import * -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. - """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - -class GaussianDiffusion: - def __init__(self,betas, loss_type, model_mean_type, model_var_type): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. 
- alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. - """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t) - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), 
self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) - - if clip_denoised: - x_recon = torch.clamp(x_recon, -.5, .5) - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape == data.shape - assert model_variance.shape == model_log_variance.shape == data.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) - assert noise.shape == data.shape - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1)) - - sample = model_mean - if use_var: - sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - assert sample.shape == pred_xstart.shape - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, denoise_fn, shape, device, - noise_fn=torch.randn, constrain_fn=lambda x, t:x, - clip_denoised=True, max_timestep=None, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - if max_timestep is None: - final_time = self.num_timesteps - else: - final_time = max_timestep - - assert isinstance(shape, (tuple, list)) - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - for t in reversed(range(0, final_time if not keep_running else len(self.betas))): - img_t = constrain_fn(img_t, t) - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False).detach() - - - assert img_t.shape == shape - return img_t - - def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): - - assert t >= 1 - - t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) - encoding = self.q_sample(x0, t_vec) - - img_t = encoding - - for k in reversed(range(0,t)): - img_t = constrain_fn(img_t, k) - t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False, use_var=True).detach() - - - return img_t - - def interpolate(self, x0, x1, t, lamb, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): - - assert t 
>= 1 - - t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) - encoding0 = self.q_sample(x0, t_vec) - encoding1 = self.q_sample(x1, t_vec) - - enc = encoding0 * lamb + (1-lamb) * encoding1 - - img_t = enc - - for k in reversed(range(0,t)): - img_t = constrain_fn(img_t, k) - t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False, use_var=True).detach() - - - return img_t - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) - - self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - assert out.shape == torch.Size([B, D, N]) - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x, - clip_denoised=False, max_timestep=None, - keep_running=False): - return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, - constrain_fn=constrain_fn, - clip_denoised=clip_denoised, max_timestep=max_timestep, - keep_running=keep_running) - - def reconstruct(self, x0, t, constrain_fn=lambda x, t:x): - - return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn) - - def interpolate(self, x0, x1, t, lamb, constrain_fn=lambda x, t:x): - - return self.diffusion.interpolate(x0, x1, t, lamb, self._denoise, constrain_fn=constrain_fn) - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - - -def 
get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - -def get_constrain_function(ground_truth, mask, eps, num_steps=1): - ''' - - :param target_shape_constraint: target voxels - :return: constrained x - ''' - # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2)) - eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 )) - def constrain_fn(x, t): - eps_ = eps_all[t] if (t<1000) else 0 - for _ in range(num_steps): - x = x - eps_ * ((x - ground_truth) * mask) - - - return x - return constrain_fn - - -############################################################################# - -def get_dataset(dataroot, npoints,category,use_mask=False): - tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True, use_mask = use_mask) - te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, - categories=category, split='val', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - use_mask=use_mask - ) - return tr_dataset, te_dataset - -def get_svr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - radius=3, elev=-89, azim=180, img_size=512, focal_length=1000, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - -def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - - - return te_dataset - -def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - 
categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - - -def evaluate_gen(opt, ref_pcs, logger): - - if ref_pcs is None: - _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'): - x = data['test_points'] - m, s = data['mean'].float(), data['std'].float() - - ref.append(x*s + m) - - ref_pcs = torch.cat(ref, dim=0).contiguous() - - logger.info("Loading sample path: %s" - % (opt.eval_path)) - sample_pcs = torch.load(opt.eval_path).contiguous() - - logger.info("Generation sample size:%s reference size: %s" - % (sample_pcs.size(), ref_pcs.size())) - - - # Compute metrics - # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - # results = {k: (v.cpu().detach().item() - # if not isinstance(v, float) else v) for k, v in results.items()} - # - # pprint(results) - # logger.info(results) - - jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy()) - pprint('JSD: {}'.format(jsd)) - logger.info('JSD: {}'.format(jsd)) - -def evaluate_recon(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - randind = i%24 - gt_all = data['test_points'][:,randind:randind+1] - x_all = data['sv_points'][:,randind:randind+1] - mask_all= data['masks'][:,randind:randind+1] - img_all = data['image'][:,randind:randind+1] - - - B,V,N,C = x_all.shape - - gt = gt_all.reshape(B*V,N,C).transpose(1,2).contiguous() - x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous() - mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous() - img = img_all.reshape(B*V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) - # for t in [10]: - # t_vec = torch.empty(gt.shape[0], dtype=torch.int64, device='cuda').fill_(80) - # recon = netE.diffusion.q_sample(gt.cuda(), t_vec).detach().cpu() - recon = netE.gen_samples(x.shape, 'cuda', - constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - opt.constrain_steps), - clip_denoised=False).detach().cpu() - - # recon = recon.transpose(1, 2).contiguous() - # x = x.transpose(1, 2).contiguous() - # gt = gt.transpose(1, 2).contiguous() - # write_to_xml_batch(os.path.join(save_dir, 'intermediate_%03d' % i), - # (recon.detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy()) - # 
write_to_xml_batch(os.path.join(save_dir, 'x_%03d' % i), - # (gt.detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy()) - # write_to_xml_batch(os.path.join(save_dir, 'noise_%03d' % i), - # (torch.randn_like(gt).detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy()) - # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): - # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) - # - # k+=1 - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - images.append(img) - - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean().item() for key, val in results.items()}) - logger.info({key: val.mean().item() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - # - # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - # - # results = {k: (v.cpu().detach().item() - # if not isinstance(v, float) else v) for k, v in results.items()} - # pprint(results) - # logger.info(results) - -def evaluate_recon_mvr(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - mask_all= data['masks'] - # img_all = data['image'] - - - B,V,N,C = x_all.shape - gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) - - x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - # img = img_all.reshape(B * V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x.shape, 'cuda', - constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - opt.constrain_steps), - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - # img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all 
* s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - # images.append(img) - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - # images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean().item() for key, val in results.items()}) - logger.info({key: val.mean().item() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - # - # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - # - # results = {k: (v.cpu().detach().item() - # if not isinstance(v, float) else v) for k, v in results.items()} - # pprint(results) - # logger.info(results) - - del ref_pcs, masked - - -def generate(netE, opt, logger): - - _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes) - - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - with torch.no_grad(): - - samples = [] - ref = [] - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'): - - x = data['test_points'].transpose(1,2) - m, s = data['mean'].float(), data['std'].float() - - gen = netE.gen_samples(x.shape, - 'cuda', clip_denoised=False).detach().cpu() - - gen = gen.transpose(1,2).contiguous() - x = x.transpose(1,2).contiguous() - - - - gen = gen * s + m - x = x * s + m - samples.append(gen) - ref.append(x) - - visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None, - None, None) - - write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), cat='chair') - - samples = torch.cat(samples, dim=0) - ref = torch.cat(ref, dim=0) - - torch.save(samples, opt.eval_path) - - - - return ref - - - -def generate_multimodal(opt, netE, save_dir, logger): - test_dataset = get_svr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - mask_all= data['masks'] - img_all = data['image'] - - - # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) - # for t in [10]: - # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - # opt.constrain_steps)).detach().cpu() - - for v in range(10): - x = x_all.transpose(1, 2).contiguous() - mask = mask_all.transpose(1, 2).contiguous() - 
img = img_all - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x.shape, 'cuda', - constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - opt.constrain_steps), - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))): - - im = np.fliplr(np.flipud(d[-1])) - plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray') - write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - - - - - - -def main(opt): - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-02-45/epoch_1399.pth' - #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/67_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-03-38/epoch_2799.pth' - opt.netE = '/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do_best/2020-10-16-12-23-44/epoch_1799.pth'#ckpt - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - ref = None - if opt.generate: - epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1]) - opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch)) - Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True) - ref=generate(netE, opt, logger) - if opt.eval_gen: - # Evaluate generation - evaluate_gen(opt, ref, logger) - - if opt.eval_recon: - # Evaluate generation - evaluate_recon(opt, netE, outf_syn, logger) - - if opt.eval_recon_mvr: - # Evaluate generation - evaluate_recon_mvr(opt, netE, outf_syn, logger) - - if opt.generate_multimodal: - - generate_multimodal(opt, netE, outf_syn, logger) - - - - exit() - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--classes', default=['chair']) - - parser.add_argument('--batch_size', type=int, default=1, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--generate',default=True) - parser.add_argument('--eval_gen', default=False) - 
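One detail of `main` above worth flagging: checkpoints are ordered with `int(x.strip('.pth').split('_')[-1])`. `str.strip` removes a *character set* from both ends rather than a literal suffix; it happens to work for `epoch_*.pth` names, but `os.path.splitext` is the robust spelling. A hedged sketch (file names are illustrative):

```python
import os

ckpts = ['epoch_99.pth', 'epoch_1399.pth', 'epoch_2799.pth']

def epoch_of(path):
    # splitext drops exactly the extension, so digits in the stem are safe
    # even when the name begins or ends with characters from '.pth'.
    return int(os.path.splitext(os.path.basename(path))[0].split('_')[-1])

for ckpt in sorted(ckpts, key=epoch_of, reverse=True):
    print(ckpt)  # newest epoch first, matching the evaluation loop above
```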
parser.add_argument('--eval_recon', default=False) - parser.add_argument('--eval_recon_mvr', default=False) - parser.add_argument('--generate_multimodal', default=False) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - # constrain function - parser.add_argument('--constrain_eps', default=.051) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do_best/2020-10-16-12-23-44/', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair/2020-10-18-13-46-21/syn/epoch_1699_samples.pth') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - set_seed(opt) - - main(opt) - - # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair/2020-10-18-13-46-21 \ No newline at end of file diff --git a/shapenet/test_plane.py b/shapenet/test_plane.py deleted file mode 100644 index 5fd3b1d..0000000 --- a/shapenet/test_plane.py +++ /dev/null @@ -1,925 +0,0 @@ -import torch -import functools -from pprint import pprint -from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD -from metrics.evaluation_metrics import compute_all_metrics, EMD_CD, distChamfer - -import torch.nn as nn -import torch.optim as optim -import torch.utils.data - -import argparse -from model.unet import get_model -from torch.distributions import Normal - -from utils.file_utils import * -from utils.visualize import * -from utils.mitsuba_renderer import write_to_xml_batch -from model.pvcnn_generation import PVCNN2Base - -from tqdm import tqdm - -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -from datasets.shapenet_data_sv import * -''' -models -''' -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - KL divergence between normal distributions parameterized by mean and log-variance. 
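In closed form, `normal_kl` evaluates the Gaussian KL divergence, taking log-variances as inputs; the return expression just below implements it term by term:

$$\mathrm{KL}\!\left(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\right)=\frac{1}{2}\left(\log\frac{\sigma_2^2}{\sigma_1^2}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{\sigma_2^2}-1\right)$$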
- """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - # Assumes data is integers [0, 1] - assert x.shape == means.shape == log_scales.shape - px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) - - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 0.5) - cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) - cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) - cdf_delta = cdf_plus - cdf_min - - log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) - assert log_probs.shape == x.shape - return log_probs - - -class GaussianDiffusion: - def __init__(self,betas, loss_type, model_mean_type, model_var_type): - self.loss_type = loss_type - self.model_mean_type = model_mean_type - self.model_var_type = model_var_type - assert isinstance(betas, np.ndarray) - self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy - assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - - # initialize twice the actual length so we can keep running for eval - # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - - alphas = 1. - betas - alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() - - self.betas = torch.from_numpy(betas).float() - self.alphas_cumprod = alphas_cumprod.float() - self.alphas_cumprod_prev = alphas_cumprod_prev.float() - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() - - betas = torch.from_numpy(betas).float() - alphas = torch.from_numpy(alphas).float() - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.posterior_variance = posterior_variance - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) - - @staticmethod - def _extract(a, t, x_shape): - """ - Extract some coefficients at specified timesteps, - then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
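For reference, the coefficients precomputed in `GaussianDiffusion.__init__` above implement the standard DDPM identities, with $\alpha_t = 1-\beta_t$ and $\bar\alpha_t=\prod_{s\le t}\alpha_s$:

$$q(x_t \mid x_0)=\mathcal{N}\!\left(\sqrt{\bar\alpha_t}\,x_0,\;(1-\bar\alpha_t)\mathbf{I}\right),\qquad q(x_{t-1} \mid x_t,x_0)=\mathcal{N}\!\left(\tilde\mu_t(x_t,x_0),\;\tilde\beta_t\mathbf{I}\right)$$

$$\tilde\beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t,\qquad \tilde\mu_t(x_t,x_0)=\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0+\frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t$$

The `_predict_xstart_from_eps` helper below inverts the forward process via $x_0=\sqrt{1/\bar\alpha_t}\,x_t-\sqrt{1/\bar\alpha_t-1}\,\epsilon$, which is exactly what `sqrt_recip_alphas_cumprod` and `sqrt_recipm1_alphas_cumprod` encode.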
- """ - bs, = t.shape - assert x_shape[0] == bs - out = torch.gather(a, 0, t) - assert out.shape == torch.Size([bs]) - return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - - - def q_mean_variance(self, x_start, t): - mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) - log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data (t == 0 means diffused for 1 step) - """ - if noise is None: - noise = torch.randn(x_start.shape, device=x_start.device) - assert noise.shape == x_start.shape - return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise - ) - - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t - ) - posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - - model_output = denoise_fn(data, t) - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: - # below: only log_variance is used in the KL computations - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), - }[self.model_var_type] - model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) - model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) - else: - raise NotImplementedError(self.model_var_type) - - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) - - if clip_denoised: - x_recon = torch.clamp(x_recon, -.5, .5) - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) - else: - raise NotImplementedError(self.loss_type) - - - assert model_mean.shape == x_recon.shape == data.shape - assert model_variance.shape == model_log_variance.shape == data.shape - if return_pred_xstart: - return model_mean, model_variance, model_log_variance, x_recon - else: - return model_mean, model_variance, model_log_variance - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), 
t, x_t.shape) * eps - ) - - ''' samples ''' - - def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True): - """ - Sample from the model - """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) - noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) - assert noise.shape == data.shape - # no noise when t == 0 - nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1)) - - sample = model_mean - if use_var: - sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - assert sample.shape == pred_xstart.shape - return (sample, pred_xstart) if return_pred_xstart else sample - - - def p_sample_loop(self, denoise_fn, shape, device, - noise_fn=torch.randn, constrain_fn=lambda x, t:x, - clip_denoised=True, max_timestep=None, keep_running=False): - """ - Generate samples - keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps - - """ - if max_timestep is None: - final_time = self.num_timesteps - else: - final_time = max_timestep - - assert isinstance(shape, (tuple, list)) - img_t = noise_fn(size=shape, dtype=torch.float, device=device) - for t in reversed(range(0, final_time if not keep_running else len(self.betas))): - img_t = constrain_fn(img_t, t) - t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False).detach() - - - assert img_t.shape == shape - return img_t - - def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): - - assert t >= 1 - - t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) - encoding = self.q_sample(x0, t_vec) - - img_t = encoding - - for k in reversed(range(0,t)): - img_t = constrain_fn(img_t, k) - t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False, use_var=True).detach() - - - return img_t - - def reconstruct2(self, x0, mask, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda forward, x, t:x): - z = noise_fn(size=x0.shape, dtype=torch.float, device=x0.device) - - for _ in range(10): - img_t = z - outputs =[None for _ in range(len(self.betas))] - for t in reversed(range(0, len(self.betas))): - - t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t) - outputs[t] = img_t.detach().cpu().clone() - - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False, use_var=True).detach() - - - img_t = torch.autograd.Variable(img_t.data, requires_grad=True) - - dist = ((img_t - x0) ** 2 * mask).sum(dim=0).mean() - grad = torch.autograd.grad(dist, [img_t])[0].detach() - - print('Dist', dist.detach().cpu().item()) - - for t in (range(0, len(outputs))): - - t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t) - - x = outputs[t].to(x0).requires_grad_() - - y = self.p_sample(denoise_fn=denoise_fn, data=x, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False, use_var=True) - - grad = torch.autograd.grad(y, [x], grad_outputs=grad)[0] - - - z = x.detach().to(x0) - 0.1 * grad.detach() - - img_t = z - for t in reversed(range(0, 
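`p_sample_loop` above is plain ancestral sampling; note that `p_sample` zeroes the noise term at $t=0$ via `nonzero_mask`. A stripped-down sketch of the control flow, where `reverse_step` is a placeholder standing in for `p_sample` bound to a trained denoiser (not the repository's API):

```python
import torch

@torch.no_grad()
def toy_sample_loop(reverse_step, shape, num_timesteps, device='cpu'):
    # Start from Gaussian noise and apply the reverse step from t = T-1
    # down to t = 0, as p_sample_loop does above.
    x = torch.randn(shape, device=device)
    for t in reversed(range(num_timesteps)):
        t_ = torch.full((shape[0],), t, dtype=torch.int64, device=device)
        x = reverse_step(x, t_)
    return x

# Smoke test with an identity "reverse step".
out = toy_sample_loop(lambda x, t: x, (4, 3, 16), num_timesteps=10)
assert out.shape == (4, 3, 16)
```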
len(self.betas))): - - t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t) - - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False, use_var=True).detach() - - return img_t - - -# class PVCNN2(PVCNN2Base): -# sa_blocks = [ -# ((32, 2, 32), (1024, 0.1, 32, (32, 64))), -# ((64, 3, 16), (256, 0.2, 32, (64, 128))), -# ((128, 3, 8), (64, 0.4, 32, (128, 256))), -# (None, (16, 0.8, 32, (256, 256, 512))), -# ] -# fp_blocks = [ -# ((256, 256), (256, 3, 8)), -# ((256, 256), (256, 3, 8)), -# ((256, 128), (128, 2, 16)), -# ((128, 128, 64), (64, 2, 32)), -# ] -# -# def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, -# voxel_resolution_multiplier=1): -# super().__init__( -# num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, -# dropout=dropout, extra_feature_channels=extra_feature_channels, -# width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier -# ) - -class PVCNN2(PVCNN2Base): - sa_blocks = [ - ((32, 2, 32), (1024, 0.1, 32, (32, 64))), - ((64, 3, 16), (256, 0.2, 32, (64, 128))), - ((128, 3, 8), (64, 0.4, 32, (128, 256))), - (None, (16, 0.8, 32, (256, 256, 512))), - ] - fp_blocks = [ - ((256, 256), (256, 3, 8)), - ((256, 256), (256, 3, 8)), - ((256, 128), (128, 2, 16)), - ((128, 128, 64), (64, 2, 32)), - ] - - def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): - super().__init__( - num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier - ) - - -class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): - super(Model, self).__init__() - self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) - - self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) - - def prior_kl(self, x0): - return self.diffusion._prior_bpd(x0) - - def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } - - - def _denoise(self, data, t): - B, D,N= data.shape - assert data.dtype == torch.float - assert t.shape == torch.Size([B]) and t.dtype == torch.int64 - - out = self.model(data, t) - - assert out.shape == torch.Size([B, D, N]) - return out - - def get_loss_iter(self, data, noises=None): - B, D, N = data.shape - t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) - - if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) - - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) - assert losses.shape == t.shape == torch.Size([B]) - return losses - - def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x, - clip_denoised=False, max_timestep=None, - keep_running=False): - return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, - constrain_fn=constrain_fn, - clip_denoised=clip_denoised, max_timestep=max_timestep, - 
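`get_loss_iter` above delegates to `diffusion.p_losses`, whose body is not part of this hunk; with the `loss_type='mse'` and `model_mean_type='eps'` flags used throughout, it amounts to the standard noise-regression objective. A sketch under that assumption, reusing `q_sample` as defined earlier in this file:

```python
import torch

def toy_loss_iter(denoise_fn, diffusion, data):
    # data: (B, C, N) point clouds; diffusion: a GaussianDiffusion-like object
    # exposing num_timesteps and q_sample as defined above. Draw one timestep
    # per shape, diffuse, and regress the injected noise (MSE).
    B = data.shape[0]
    t = torch.randint(0, diffusion.num_timesteps, size=(B,), device=data.device)
    noise = torch.randn_like(data)
    x_t = diffusion.q_sample(x_start=data, t=t, noise=noise)
    # Per-shape losses of shape (B,), matching the assert in get_loss_iter.
    return ((denoise_fn(x_t, t) - noise) ** 2).flatten(1).mean(dim=1)
```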
keep_running=keep_running) - - def reconstruct(self, x0, t, constrain_fn=lambda x, t:x): - - return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn) - - def reconstruct2(self, x0, mask, constrain_fn): - - return self.diffusion.reconstruct2(x0, mask, self._denoise, constrain_fn=constrain_fn) - - def train(self): - self.model.train() - - def eval(self): - self.model.eval() - - def multi_gpu_wrapper(self, f): - self.model = f(self.model) - - -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.1) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.2) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - - betas = b_end * np.ones(time_num, dtype=np.float64) - warmup_time = int(time_num * 0.5) - betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - else: - raise NotImplementedError(schedule_type) - return betas - -def get_constrain_function(ground_truth, mask, eps, num_steps=1): - ''' - - :param target_shape_constraint: target voxels - :return: constrained x - ''' - # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2)) - eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 500)**2 )) - def constrain_fn(x, t): - eps_ = eps_all[t] if (t<500) else 0 - for _ in range(num_steps): - x = x - eps_ * ((x - ground_truth) * mask) - - - return x - - # mask_single = mask[0, :, 0] - # num = mask_single.sum().int().item() - def constrain_fn2(forward, x, t): - - x = torch.autograd.Variable(x.data, requires_grad=True) - y = forward(x) - - - - dist = ((y - ground_truth)**2 * mask).sum(dim=0).mean() - grad = torch.autograd.grad(dist, [x], retain_graph=True)[0] - x = x - eps * (grad) - - print('Dist', dist.detach().cpu().item()) - - return x - return constrain_fn - - - -############################################################################# - -def get_dataset(dataroot, npoints,category,use_mask=False): - tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True, use_mask = use_mask) - te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, - categories=category, split='val', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - use_mask=use_mask - ) - return tr_dataset, te_dataset - -def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root, - cache=os.path.join(mesh_root, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, 
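`get_betas` above supports a linear ramp and three warmup variants. As a usage sketch with the two configurations that actually appear in these scripts (linear `1e-4 -> 2e-2` for chairs; `warm0.1` ramping `1e-5 -> 8e-3` over the first 10% of 1000 steps for airplanes):

```python
import numpy as np

# Chair config: plain linear schedule over 1000 steps.
linear = np.linspace(0.0001, 0.02, 1000)

# Airplane config: constant at beta_end, with a linear ramp over the
# first 10% of steps, mirroring the 'warm0.1' branch above.
warm = 0.008 * np.ones(1000, dtype=np.float64)
warm[:100] = np.linspace(0.00001, 0.008, 100, dtype=np.float64)

assert linear.shape == warm.shape == (1000,)
print(float(linear[0]), float(linear[-1]), float(warm[0]), float(warm[-1]))
```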
- ) - return te_dataset - -def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=category, split='train', - tr_sample_size=npoints, - te_sample_size=npoints, - scale=1., - normalize_per_shape=False, - normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', - categories=category, - npoints=npoints, sv_samples=200, - all_points_mean=tr_dataset.all_points_mean, - all_points_std=tr_dataset.all_points_std, - ) - return te_dataset - - -def evaluate_gen(opt, ref_pcs, logger): - - if ref_pcs is None: - _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'): - x = data['test_points'] - m, s = data['mean'].float(), data['std'].float() - - ref.append(x*s + m) - - ref_pcs = torch.cat(ref, dim=0).contiguous() - - logger.info("Loading sample path: %s" - % (opt.eval_path)) - sample_pcs = torch.load(opt.eval_path).contiguous() - - logger.info("Generation sample size:%s reference size: %s" - % (sample_pcs.size(), ref_pcs.size())) - - - # Compute metrics - # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - # results = {k: (v.cpu().detach().item() - # if not isinstance(v, float) else v) for k, v in results.items()} - # - # pprint(results) - # logger.info(results) - - jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy()) - pprint('JSD: {}'.format(jsd)) - logger.info('JSD: {}'.format(jsd)) - - -def evaluate_recon(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2', - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - randind = i%24 - gt_all = data['test_points'][:,randind:randind+1] - x_all = data['sv_points'][:,randind:randind+1] - mask_all= data['masks'][:,randind:randind+1] - img_all = data['image'][:,randind:randind+1] - - - B,V,N,C = x_all.shape - - x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous() - mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous() - img = img_all.reshape(B*V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) - # for t in [10]: - # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - # opt.constrain_steps)).detach().cpu() - - recon = netE.gen_samples(x.shape, 'cuda', - constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - opt.constrain_steps), - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - - # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), 
list(torch.zeros_like(x))): - # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) - # - # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # - # k+=1 - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - images.append(img) - - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean() for key, val in results.items()}) - logger.info({key: val.mean() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - # - # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - # - # results = {k: (v.cpu().detach().item() - # if not isinstance(v, float) else v) for k, v in results.items()} - # pprint(results) - # logger.info(results) - - - -def evaluate_recon_mvr(opt, netE, save_dir, logger): - test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - opt.npoints, opt.classes) - # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - ref = [] - samples = [] - images = [] - masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): - - gt_all = data['test_points'] - x_all = data['sv_points'] - mask_all= data['masks'] - img_all = data['image'] - - - B,V,N,C = x_all.shape - gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) - - # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None) - # for t in [10]: - # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - # opt.constrain_steps)).detach().cpu() - - cd_res = [] - recon_res = [] - for p in range(5): - x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - img = img_all.reshape(B * V, *img_all.shape[2:]) - - m, s = data['mean'].float(), data['std'].float() - - recon = netE.gen_samples(x.shape, 'cuda', - constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps, - opt.constrain_steps), - clip_denoised=False).detach().cpu() - - recon = recon.transpose(1, 2).contiguous() - x = x.transpose(1, 2).contiguous() - - cd = (((recon - x)**2)*mask.transpose(1, 
2).contiguous()).sum(dim=(1,2)) - - cd_res.append(cd) - recon_res.append(recon) - - cd_res = torch.stack(cd_res, dim=0) - recon_res = torch.stack(recon_res, dim=0) - _, argmin = torch.min(cd_res, 0) - recon = recon_res[argmin,torch.arange(0,argmin.shape[0])] - - # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))): - # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None) - # - # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy()) - # - # k+=1 - - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - img = img.reshape(B,V,*img.shape[1:]) - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) - samples.append(recon_adj) - images.append(img) - - ref_pcs = torch.cat(ref, dim=0) - sample_pcs = torch.cat(samples, dim=0) - images = torch.cat(images, dim=0) - masked = torch.cat(masked, dim=0) - - B, V, N, C = ref_pcs.shape - - - torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth')) - torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) - # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) - - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} - - pprint({key: val.mean().item() for key, val in results.items()}) - logger.info({key: val.mean().item() for key, val in results.items()}) - - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) - - # - # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - # - # results = {k: (v.cpu().detach().item() - # if not isinstance(v, float) else v) for k, v in results.items()} - # pprint(results) - # logger.info(results) - - - -def generate(netE, opt, logger): - - _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes) - - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) - - with torch.no_grad(): - - samples = [] - ref = [] - - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'): - - x = data['test_points'].transpose(1,2) - m, s = data['mean'].float(), data['std'].float() - - gen = netE.gen_samples(x.shape, - 'cuda', clip_denoised=False).detach().cpu() - - gen = gen.transpose(1,2).contiguous() - x = x.transpose(1,2).contiguous() - - - - gen = gen * s + m - x = x * s + m - samples.append(gen) - ref.append(x) - - # visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x_%03d.png'%i), gen[:64], None, - # None, None) - # export_to_pc_batch(os.path.join(str(Path(opt.eval_path).parent), 'ply_%03d'%i), - # gen[:64].numpy()) - write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), cat='airplane') - - samples = torch.cat(samples, dim=0) - ref = torch.cat(ref, dim=0) - - torch.save(samples, opt.eval_path) - - - - return ref - -def main(opt): - exp_id = os.path.splitext(os.path.basename(__file__))[0] - dir_id = 
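The multi-trial block above draws five completions per partial view and keeps the one closest to the observed points. The selection indexing, reduced to a toy example (sizes are illustrative):

```python
import torch

K, BV, N, C = 5, 4, 128, 3
recon_res = torch.randn(K, BV, N, C)        # K candidate completions per view
cd_res = torch.rand(K, BV)                  # distance of each candidate to input

# Pick, for every view, the candidate with the smallest distance.
_, argmin = torch.min(cd_res, dim=0)
best = recon_res[argmin, torch.arange(BV)]  # advanced indexing -> (BV, N, C)
assert best.shape == (BV, N, C)
```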
os.path.dirname(__file__) - output_dir = get_output_dir(dir_id, exp_id) - copy_source(__file__, output_dir) - logger = setup_logging(output_dir) - - outf_syn, = setup_output_subdirs(output_dir, 'syn') - - betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) - netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - - if opt.cuda: - netE.cuda() - - def _transform_(m): - return nn.parallel.DataParallel(m) - - netE = netE.cuda() - netE.multi_gpu_wrapper(_transform_) - - # netE.eval() - - ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')] - - with torch.no_grad(): - for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )): - #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-02-45/epoch_1399.pth' - #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/67_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-03-38/epoch_2799.pth' - # opt.netE = '/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do_best/2020-10-13-13-33-53/epoch_2899.pth'#ckpt - opt.netE = '/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do/2020-10-07-13-26-10/epoch_2299.pth' - logger.info("Resume Path:%s" % opt.netE) - - resumed_param = torch.load(opt.netE) - netE.load_state_dict(resumed_param['model_state']) - - - ref = None - if opt.generate: - epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1]) - opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch)) - Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True) - ref=generate(netE, opt, logger) - if opt.eval_gen: - # Evaluate generation - evaluate_gen(opt, ref, logger) - - if opt.eval_recon: - # Evaluate generation - evaluate_recon(opt, netE, outf_syn, logger) - - if opt.eval_recon_mvr: - # Evaluate generation - evaluate_recon_mvr(opt, netE, outf_syn, logger) - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--classes', default=['airplane']) - - parser.add_argument('--batch_size', type=int, default=20, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') - - parser.add_argument('--generate',default=False) - parser.add_argument('--eval_gen', default=True) - parser.add_argument('--eval_recon', default=False) - parser.add_argument('--eval_recon_mvr', default=False) - - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - '''model''' - parser.add_argument('--beta_start', default=0.00001) - parser.add_argument('--beta_end', default=0.008) - parser.add_argument('--schedule_type', default='warm0.1') - parser.add_argument('--time_num', default=1000) - - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') - - # constrain function - parser.add_argument('--constrain_eps', 
default=.1) - parser.add_argument('--constrain_steps', type=int, default=1) - - parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do_best/2020-10-13-13-33-53', help="path to netE (to continue training)") - - '''eval''' - - parser.add_argument('--eval_path', - default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test_plane/2020-10-18-13-49-20/syn/epoch_2499_samples.pth') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') - - opt = parser.parse_args() - - if torch.cuda.is_available(): - opt.cuda = True - else: - opt.cuda = False - - return opt -if __name__ == '__main__': - opt = parse_args() - set_seed(opt) - - main(opt) diff --git a/shape_completion/test_completion.py b/test_completion.py similarity index 98% rename from shape_completion/test_completion.py rename to test_completion.py index 9be1ed7..4bbd5a1 100644 --- a/shape_completion/test_completion.py +++ b/test_completion.py @@ -7,9 +7,7 @@ import torch.utils.data import argparse from torch.distributions import Normal -from utils.visualize import * from utils.file_utils import * -from utils.mitsuba_renderer import write_to_xml_batch from model.pvcnn_completion import PVCNN2Base from datasets.shapenet_data_pc import ShapeNet15kPointClouds @@ -579,9 +577,8 @@ def main(opt): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - help='input batch size') + parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/') + parser.add_argument('--dataroot_sv', default='GenReData/') parser.add_argument('--category', default='chair') parser.add_argument('--batch_size', type=int, default=50, help='input batch size') diff --git a/shapenet/test_generation.py b/test_generation.py similarity index 98% rename from shapenet/test_generation.py rename to test_generation.py index 979e0d0..9659ba0 100644 --- a/shapenet/test_generation.py +++ b/test_generation.py @@ -4,16 +4,13 @@ from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD from metrics.evaluation_metrics import compute_all_metrics, EMD_CD import torch.nn as nn -import torch.optim as optim import torch.utils.data import argparse -from model.unet import get_model from torch.distributions import Normal from utils.file_utils import * from utils.visualize import * -from utils.mitsuba_renderer import write_to_xml_batch from model.pvcnn_generation import PVCNN2Base from tqdm import tqdm @@ -271,6 +268,8 @@ class PVCNN2(PVCNN2Base): width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier ) + + class Model(nn.Module): def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): super(Model, self).__init__() @@ -534,8 +533,8 @@ def main(opt): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--category', default='car') + parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/') + parser.add_argument('--category', default='chair') parser.add_argument('--batch_size', 
type=int, default=50, help='input batch size') parser.add_argument('--workers', type=int, default=16, help='workers') @@ -585,5 +584,3 @@ if __name__ == '__main__': set_seed(opt) main(opt) - - # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair \ No newline at end of file diff --git a/shape_completion/train_completion.py b/train_completion.py similarity index 98% rename from shape_completion/train_completion.py rename to train_completion.py index 6c3aae5..8d81818 100644 --- a/shape_completion/train_completion.py +++ b/train_completion.py @@ -4,7 +4,6 @@ import torch.optim as optim import torch.utils.data import argparse -from model.unet import get_model from torch.distributions import Normal from utils.file_utils import * @@ -559,7 +558,7 @@ def train(gpu, opt, output_dir, noises_init): ''' data ''' - train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes) + train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.category) dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) @@ -609,10 +608,6 @@ def train(gpu, opt, output_dir, noises_init): else: start_epoch = 0 - def new_x_chain(x, num_chain): - return torch.randn(num_chain, *x.shape[1:], device=x.device) - - for epoch in range(start_epoch, opt.niter): @@ -754,7 +749,7 @@ def main(): ''' workaround ''' - train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes) + train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.category) noises_init = torch.randn(len(train_dataset), opt.npoints-opt.svpoints, opt.nc) if opt.dist_url == "env://" and opt.world_size == -1: @@ -772,9 +767,8 @@ def main(): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') - parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed', - help='input batch size') + parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/') + parser.add_argument('--dataroot_sv', default='GenReData/') parser.add_argument('--category', default='chair') parser.add_argument('--bs', type=int, default=48, help='input batch size') diff --git a/shapenet/train_generation.py b/train_generation.py similarity index 99% rename from shapenet/train_generation.py rename to train_generation.py index 9141740..83d39ad 100644 --- a/shapenet/train_generation.py +++ b/train_generation.py @@ -4,7 +4,6 @@ import torch.optim as optim import torch.utils.data import argparse -from model.unet import get_model from torch.distributions import Normal from utils.file_utils import * @@ -787,10 +786,10 @@ def main(): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size') + parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/') parser.add_argument('--category', default='chair') - parser.add_argument('--bs', type=int, default=48, help='input batch size') + parser.add_argument('--bs', type=int, default=16, help='input batch size') parser.add_argument('--workers', type=int, default=16, help='workers') parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') diff --git a/utils/binvox_rw.py b/utils/binvox_rw.py deleted file mode 100644 index 
73190d2..0000000 --- a/utils/binvox_rw.py +++ /dev/null @@ -1,266 +0,0 @@ -# Copyright (C) 2012 Daniel Maturana -# This file is part of binvox-rw-py. -# -# binvox-rw-py is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# binvox-rw-py is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with binvox-rw-py. If not, see . -# - -""" -Binvox to Numpy and back. ->>> import numpy as np ->>> import binvox_rw ->>> with open('chair.binvox', 'rb') as f: -... m1 = binvox_rw.read_as_3d_array(f) -... ->>> m1.dims -[32, 32, 32] ->>> m1.scale -41.133000000000003 ->>> m1.translate -[0.0, 0.0, 0.0] ->>> with open('chair_out.binvox', 'wb') as f: -... m1.write(f) -... ->>> with open('chair_out.binvox', 'rb') as f: -... m2 = binvox_rw.read_as_3d_array(f) -... ->>> m1.dims==m2.dims -True ->>> m1.scale==m2.scale -True ->>> m1.translate==m2.translate -True ->>> np.all(m1.data==m2.data) -True ->>> with open('chair.binvox', 'rb') as f: -... md = binvox_rw.read_as_3d_array(f) -... ->>> with open('chair.binvox', 'rb') as f: -... ms = binvox_rw.read_as_coord_array(f) -... ->>> data_ds = binvox_rw.dense_to_sparse(md.data) ->>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32) ->>> np.all(data_sd==md.data) -True ->>> # the ordering of elements returned by numpy.nonzero changes with axis ->>> # ordering, so to compare for equality we first lexically sort the voxels. ->>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)]) -True -""" - -import numpy as np - -class Voxels(object): - """ Holds a binvox model. - data is either a three-dimensional numpy boolean array (dense representation) - or a two-dimensional numpy float array (coordinate representation). - dims, translate and scale are the model metadata. - dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. - scale and translate relate the voxels to the original model coordinates. - To translate voxel coordinates i, j, k to original coordinates x, y, z: - x_n = (i+.5)/dims[0] - y_n = (j+.5)/dims[1] - z_n = (k+.5)/dims[2] - x = scale*x_n + translate[0] - y = scale*y_n + translate[1] - z = scale*z_n + translate[2] - """ - - def __init__(self, data, dims, translate, scale, axis_order): - self.data = data - self.dims = dims - self.translate = translate - self.scale = scale - assert (axis_order in ('xzy', 'xyz')) - self.axis_order = axis_order - - def clone(self): - data = self.data.copy() - dims = self.dims[:] - translate = self.translate[:] - return Voxels(data, dims, translate, self.scale, self.axis_order) - - def write(self, fp): - write(self, fp) - -def read_header(fp): - """ Read binvox header. Mostly meant for internal use. - """ - line = fp.readline().strip() - if not line.startswith(b'#binvox'): - raise IOError('Not a binvox file') - dims = list(map(int, fp.readline().strip().split(b' ')[1:])) - translate = list(map(float, fp.readline().strip().split(b' ')[1:])) - scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] - line = fp.readline() - return dims, translate, scale - -def read_as_3d_array(fp, fix_coords=True): - """ Read binary binvox format as array. 
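The `Voxels` docstring above defines the voxel-index-to-world mapping; a runnable restatement with toy inputs, using the example metadata from the module docstring:

```python
import numpy as np

def voxel_to_world(ijk, dims, scale, translate):
    # Center-of-voxel mapping from the Voxels docstring above:
    # x_n = (i + .5) / dims[0], then x = scale * x_n + translate[0], per axis.
    ijk = np.asarray(ijk, dtype=np.float64)
    dims = np.asarray(dims, dtype=np.float64)
    return scale * (ijk + 0.5) / dims + np.asarray(translate, dtype=np.float64)

# First voxel center of a 32^3 model with scale 41.133 and zero translation.
print(voxel_to_world([0, 0, 0], [32, 32, 32], 41.133, [0.0, 0.0, 0.0]))
```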
- Returns the model with accompanying metadata. - Voxels are stored in a three-dimensional numpy array, which is simple and - direct, but may use a lot of memory for large models. (Storage requirements - are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy - boolean arrays use a byte per element). - Doesn't do any checks on input except for the '#binvox' line. - """ - dims, translate, scale = read_header(fp) - raw_data = np.frombuffer(fp.read(), dtype=np.uint8) - # if just using reshape() on the raw data: - # indexing the array as array[i,j,k], the indices map into the - # coords as: - # i -> x - # j -> z - # k -> y - # if fix_coords is true, then data is rearranged so that - # mapping is - # i -> x - # j -> y - # k -> z - values, counts = raw_data[::2], raw_data[1::2] - data = np.repeat(values, counts).astype(np.bool) - data = data.reshape(dims) - if fix_coords: - # xzy to xyz TODO the right thing - data = np.transpose(data, (0, 2, 1)) - axis_order = 'xyz' - else: - axis_order = 'xzy' - return Voxels(data, dims, translate, scale, axis_order) - -def read_as_coord_array(fp, fix_coords=True): - """ Read binary binvox format as coordinates. - Returns binvox model with voxels in a "coordinate" representation, i.e. an - 3 x N array where N is the number of nonzero voxels. Each column - corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates - of the voxel. (The odd ordering is due to the way binvox format lays out - data). Note that coordinates refer to the binvox voxels, without any - scaling or translation. - Use this to save memory if your model is very sparse (mostly empty). - Doesn't do any checks on input except for the '#binvox' line. - """ - dims, translate, scale = read_header(fp) - raw_data = np.frombuffer(fp.read(), dtype=np.uint8) - - values, counts = raw_data[::2], raw_data[1::2] - - sz = np.prod(dims) - index, end_index = 0, 0 - end_indices = np.cumsum(counts) - indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) - - values = values.astype(np.bool) - indices = indices[values] - end_indices = end_indices[values] - - nz_voxels = [] - for index, end_index in zip(indices, end_indices): - nz_voxels.extend(range(index, end_index)) - nz_voxels = np.array(nz_voxels) - # TODO are these dims correct? - # according to docs, - # index = x * wxh + z * width + y; // wxh = width * height = d * d - - x = nz_voxels / (dims[0]*dims[1]) - zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y - z = zwpy / dims[0] - y = zwpy % dims[0] - if fix_coords: - data = np.vstack((x, y, z)) - axis_order = 'xyz' - else: - data = np.vstack((x, z, y)) - axis_order = 'xzy' - - #return Voxels(data, dims, translate, scale, axis_order) - return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) - -def dense_to_sparse(voxel_data, dtype=np.int): - """ From dense representation to sparse (coordinate) representation. - No coordinate reordering. 
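Binvox payloads are run-length encoded as `(value, count)` byte pairs, which is the expansion `read_as_3d_array` above performs with `np.repeat`. A toy decode (using the builtin `bool`; the removed file used the long-deprecated `np.bool` alias):

```python
import numpy as np

raw = np.array([0, 5, 1, 3, 0, 2], dtype=np.uint8)  # toy RLE stream
values, counts = raw[::2], raw[1::2]
dense = np.repeat(values, counts).astype(bool)
assert dense.tolist() == [False] * 5 + [True] * 3 + [False] * 2
```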
- """ - if voxel_data.ndim!=3: - raise ValueError('voxel_data is wrong shape; should be 3D array.') - return np.asarray(np.nonzero(voxel_data), dtype) - -def sparse_to_dense(voxel_data, dims, dtype=np.bool): - if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: - raise ValueError('voxel_data is wrong shape; should be 3xN array.') - if np.isscalar(dims): - dims = [dims]*3 - dims = np.atleast_2d(dims).T - # truncate to integers - xyz = voxel_data.astype(np.int) - # discard voxels that fall outside dims - valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) - xyz = xyz[:,valid_ix] - out = np.zeros(dims.flatten(), dtype=dtype) - out[tuple(xyz)] = True - return out - -#def get_linear_index(x, y, z, dims): - #""" Assuming xzy order. (y increasing fastest. - #TODO ensure this is right when dims are not all same - #""" - #return x*(dims[1]*dims[2]) + z*dims[1] + y - -def write(voxel_model, fp): - """ Write binary binvox format. - Note that when saving a model in sparse (coordinate) format, it is first - converted to dense format. - Doesn't check if the model is 'sane'. - """ - if voxel_model.data.ndim==2: - # TODO avoid conversion to dense - dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) - else: - dense_voxel_data = voxel_model.data - - fp.write('#binvox 1\n') - fp.write('dim '+' '.join(map(str, voxel_model.dims))+'\n') - fp.write('translate '+' '.join(map(str, voxel_model.translate))+'\n') - fp.write('scale '+str(voxel_model.scale)+'\n') - fp.write('data\n') - if not voxel_model.axis_order in ('xzy', 'xyz'): - raise ValueError('Unsupported voxel model axis order') - - if voxel_model.axis_order=='xzy': - voxels_flat = dense_voxel_data.flatten() - elif voxel_model.axis_order=='xyz': - voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() - - # keep a sort of state machine for writing run length encoding - state = voxels_flat[0] - ctr = 0 - for c in voxels_flat: - if c==state: - ctr += 1 - # if ctr hits max, dump - if ctr==255: - fp.write(chr(state)) - fp.write(chr(ctr)) - ctr = 0 - else: - # if switch state, dump - fp.write(chr(state)) - fp.write(chr(ctr)) - state = c - ctr = 1 - # flush out remainders - if ctr > 0: - fp.write(chr(state)) - fp.write(chr(ctr)) - -if __name__ == '__main__': - import doctest - doctest.testmod() \ No newline at end of file diff --git a/utils/conversion.py b/utils/conversion.py deleted file mode 100644 index 1d90335..0000000 --- a/utils/conversion.py +++ /dev/null @@ -1,46 +0,0 @@ - -from skimage import measure -import numpy as np - - -def get_mesh(tsdf_vol, color_vol, threshold=0, vol_max=.5, vol_min=-.5): - """Compute a mesh from the voxel volume using marching cubes. 
- """ - vol_origin = vol_min - voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1] - - # Marching cubes - verts, faces, norms, vals = measure.marching_cubes_lewiner(tsdf_vol, level=threshold) - verts_ind = np.round(verts).astype(int) - verts = verts * voxel_size + vol_origin # voxel grid coordinates to world coordinates - - # Get vertex colors - if color_vol is None: - return verts, faces, norms - colors = color_vol[:, verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].T - - return verts, faces, norms, colors - - -def get_point_cloud(tsdf_vol, color_vol, vol_max=0.5, vol_min=-0.5): - vol_origin = vol_min - voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1] - # Marching cubes - verts = measure.marching_cubes_lewiner(tsdf_vol, level=0)[0] - verts_ind = np.round(verts).astype(int) - verts = verts * voxel_size + vol_origin - - # Get vertex colors - colors = color_vol[:, verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].T - - pc = np.hstack([verts, colors]) - - return pc - -def sparse_to_dense_voxel(coords, feats, res): - coords = coords.astype('int64', copy=False) - a = np.zeros((res, res, res), dtype=feats.dtype) - - a[coords[:,0],coords[:,1],coords[:,2] ] = feats[:,0].astype(a.dtype, copy=False) - - return a \ No newline at end of file diff --git a/utils/mitsuba_renderer.py b/utils/mitsuba_renderer.py deleted file mode 100644 index 26cc3bf..0000000 --- a/utils/mitsuba_renderer.py +++ /dev/null @@ -1,146 +0,0 @@ -import numpy as np -from pathlib import Path -import os - - -def standardize_bbox(pcl, points_per_object, scale=None): - pt_indices = np.random.choice(pcl.shape[0], points_per_object, replace=False) - np.random.shuffle(pt_indices) - pcl = pcl[pt_indices] # n by 3 - mins = np.amin(pcl, axis=0) - maxs = np.amax(pcl, axis=0) - center = (mins + maxs) / 2. 
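`get_mesh` above rescales marching-cubes vertices from grid indices to world coordinates (note that newer scikit-image releases expose `measure.marching_cubes_lewiner` as `measure.marching_cubes`). The rescaling in isolation:

```python
import numpy as np

def grid_to_world(verts, res, vol_min=-0.5, vol_max=0.5):
    # Vertex indices in a res^3 grid map to world coordinates spanning
    # [vol_min, vol_max], as in get_mesh above.
    voxel_size = (vol_max - vol_min) / res
    return verts * voxel_size + vol_min

corners = np.array([[0.0, 0.0, 0.0], [32.0, 32.0, 32.0]])
print(grid_to_world(corners, res=32))  # -> [-0.5]*3 and [0.5]*3
```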
- if scale is None: - scale = np.amax(maxs - mins) - result = ((pcl - center) / scale).astype(np.float32) # [-0.5, 0.5] - return result - - -xml_head = \ - """ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - """ - -xml_ball_segment = \ - """ - - - - - - - - - - """ - -xml_tail = \ - """ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - """ - - -def colormap_fn(x, y, z): - vec = np.array([x, y, z]) - vec = np.clip(vec, 0.001, 1.0) - norm = np.sqrt(np.sum(vec ** 2)) - vec /= norm - return [vec[0], vec[1], vec[2]] - - -color_dict = {'r': [163, 102, 96], 'g': [20, 130, 3], - 'o': [145, 128, 47], 'b': [91, 102, 112], 'p':[133,111,139], 'br':[111,92,81]} - -color_map = {'airplane': 'r', 'chair': 'o', 'car': 'b', 'table': 'p', 'lamp':'br'} -fov_map = {'airplane': 12, 'chair': 16, 'car':15, 'table': 13, 'lamp':13} -radius_map = {'airplane': 0.02, 'chair': 0.035, 'car': 0.01, 'table':0.035, 'lamp':0.035} - -def write_to_xml_batch(dir, pcl_batch, filenames=None, color_batch=None, cat='airplane'): - default_color = color_map[cat] - Path(dir).mkdir(parents=True, exist_ok=True) - if filenames is not None: - assert len(filenames) == pcl_batch.shape[0] - # mins = np.amin(pcl_batch, axis=(0,1)) - # maxs = np.amax(pcl_batch, axis=(0,1)) - # scale = 1; print(np.amax(maxs - mins)) - - for k, pcl in enumerate(pcl_batch): - xml_segments = [xml_head.format(fov_map[cat])] - pcl = standardize_bbox(pcl, pcl.shape[0]) - pcl = pcl[:, [2, 0, 1]] - pcl[:, 0] *= -1 - pcl[:, 2] += 0.0125 - for i in range(pcl.shape[0]): - if color_batch is not None: - color = color_batch[k, i] - else: - color = np.array(color_dict[default_color]) / 255 - # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125) - xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color)) - xml_segments.append( - xml_tail.format(pcl[:, 2].min())) - - xml_content = str.join('', xml_segments) - - if filenames is None: - fn = 'sample_{}.xml'.format(k) - else: - fn = filenames[k] - with open(os.path.join(dir, fn), 'w') as f: - f.write(xml_content) diff --git a/utils/mitsuba_renderer2.py b/utils/mitsuba_renderer2.py deleted file mode 100644 index cb725fd..0000000 --- a/utils/mitsuba_renderer2.py +++ /dev/null @@ -1,170 +0,0 @@ -import numpy as np -from pathlib import Path -import os - - -def standardize_bbox(pcl, points_per_object): - pt_indices = np.random.choice(pcl.shape[0], points_per_object, replace=False) - np.random.shuffle(pt_indices) - pcl = pcl[pt_indices] # n by 3 - mins = np.amin(pcl, axis=0) - maxs = np.amax(pcl, axis=0) - center = (mins + maxs) / 2. 
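`standardize_bbox` above (duplicated in both renderer modules) recenters a cloud on its bounding box and scales the longest side to unit length, so coordinates land in `[-0.5, 0.5]`. A self-contained check, with the random subsampling step omitted:

```python
import numpy as np

def standardize_bbox_toy(pcl):
    # Center on the bounding box and divide by the longest side.
    mins, maxs = pcl.min(axis=0), pcl.max(axis=0)
    center = (mins + maxs) / 2.0
    scale = (maxs - mins).max()
    return ((pcl - center) / scale).astype(np.float32)

pcl = np.random.rand(2048, 3) * np.array([1.0, 2.0, 4.0])
out = standardize_bbox_toy(pcl)
assert out.min() >= -0.5 - 1e-6 and out.max() <= 0.5 + 1e-6
```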
diff --git a/utils/mitsuba_renderer2.py b/utils/mitsuba_renderer2.py
deleted file mode 100644
index cb725fd..0000000
--- a/utils/mitsuba_renderer2.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import numpy as np
-from pathlib import Path
-import os
-
-
-def standardize_bbox(pcl, points_per_object):
-    pt_indices = np.random.choice(pcl.shape[0], points_per_object, replace=False)
-    np.random.shuffle(pt_indices)
-    pcl = pcl[pt_indices]  # n by 3
-    mins = np.amin(pcl, axis=0)
-    maxs = np.amax(pcl, axis=0)
-    center = (mins + maxs) / 2.
-    scale = np.amax(maxs - mins)
-    result = ((pcl - center) / scale).astype(np.float32)  # [-0.5, 0.5]
-    return result
-
-
-xml_head = \
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    """
-
-xml_ball_segment = \
-    """
-
-
-
-
-
-
-
-
-
-
-    """
-
-xml_tail = \
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    """
-
-
-def colormap_fn(x, y, z):
-    vec = np.array([x, y, z])
-    vec = np.clip(vec, 0.001, 1.0)
-    norm = np.sqrt(np.sum(vec ** 2))
-    vec /= norm
-    return [vec[0], vec[1], vec[2]]
-
-
-color_dict = {'r': [163, 102, 96], 'p': [133, 111, 139], 'g': [20, 130, 3],
-              'o': [145, 128, 47], 'b': [91, 102, 112]}
-
-color_map = {'airplane': 'r', 'chair': 'o', 'car': 'b', 'table': 'p'}
-fov_map = {'airplane': 12, 'chair': 15, 'car': 12, 'table': 12}
-radius_map = {'airplane': 0.0175, 'chair': 0.035, 'car': 0.025, 'table': 0.02}
-
-def write_to_xml_batch(dir, pcl_batch, color_batch=None, cat='airplane', elev=15, azim=45, radius=np.sqrt(18)):
-    elev_rad = elev * np.pi / 180
-    azim_rad = azim * np.pi / 180
-
-    x = radius * np.cos(elev_rad) * np.cos(azim_rad)
-    y = radius * np.cos(elev_rad) * np.sin(azim_rad)
-    z = radius * np.sin(elev_rad)
-
-    default_color = color_map[cat]
-    Path(dir).mkdir(parents=True, exist_ok=True)
-    for k, pcl in enumerate(pcl_batch):
-        xml_segments = [xml_head.format(x, y, z)]
-        pcl = standardize_bbox(pcl, pcl.shape[0])
-        pcl = pcl[:, [2, 0, 1]]
-        pcl[:, 0] *= -1
-        pcl[:, 2] += 0.0125
-        for i in range(pcl.shape[0]):
-            if color_batch is not None:
-                color = color_batch[k, i]
-            else:
-                color = np.array(color_dict[default_color]) / 255
-            # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125)
-            xml_segments.append(xml_ball_segment.format(0.0175, pcl[i, 0], pcl[i, 1], pcl[i, 2], *color))
-        xml_segments.append(
-            xml_tail.format(pcl[:, 2].min()))
-
-        xml_content = str.join('', xml_segments)
-
-        with open(os.path.join(dir, 'sample_{}.xml'.format(k)), 'w') as f:
-            f.write(xml_content)
-
-def write_to_xml(file, pcl, cat='airplane', elev=15, azim=45, radius=np.sqrt(18)):
-    assert pcl.ndim == 2
-    elev_rad = elev * np.pi / 180
-    azim_rad = azim * np.pi / 180
-
-    x = radius * np.cos(elev_rad) * np.cos(azim_rad)
-    y = radius * np.cos(elev_rad) * np.sin(azim_rad)
-    z = radius * np.sin(elev_rad)
-
-    default_color = color_map[cat]
-
-    xml_segments = [xml_head.format(x, y, z)]
-    pcl = standardize_bbox(pcl, pcl.shape[0])
-    pcl = pcl[:, [2, 0, 1]]
-    pcl[:, 0] *= -1
-    pcl[:, 2] += 0.0125
-    for i in range(pcl.shape[0]):
-        color = np.array(color_dict[default_color]) / 255
-        # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125)
-        xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color))
-    xml_segments.append(
-        xml_tail.format(pcl[:, 2].min()))
-
-    xml_content = str.join('', xml_segments)
-
-    with open(file, 'w') as f:
-        f.write(xml_content)
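This second renderer differs mainly in parameterizing the camera by elevation, azimuth, and radius, converting those to a Cartesian eye position with the usual spherical-coordinate formulas. A quick self-contained check with the function's defaults (`elev=15`, `azim=45`, `radius=sqrt(18)`):

```python
import numpy as np

elev, azim, radius = 15, 45, np.sqrt(18)  # the function's defaults
elev_rad, azim_rad = np.deg2rad(elev), np.deg2rad(azim)

x = radius * np.cos(elev_rad) * np.cos(azim_rad)
y = radius * np.cos(elev_rad) * np.sin(azim_rad)
z = radius * np.sin(elev_rad)

# The eye stays on a sphere of the requested radius around the object.
assert np.isclose(x ** 2 + y ** 2 + z ** 2, radius ** 2)
print(x, y, z)  # approximately (2.898, 2.898, 1.098)
```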
- """ - axis = np.asarray(axis) - axis = axis / np.sqrt(np.dot(axis, axis)) - a = np.cos(theta / 2.0) - b, c, d = -axis * np.sin(theta / 2.0) - aa, bb, cc, dd = a * a, b * b, c * c, d * d - bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d - return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], - [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], - [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) - -def rotate(vertices, faces): - ''' - vertices: [numpoints, 3] - ''' - N = rotation_matrix([1, 0, 0], 3* np.pi / 4).transpose() - # M = rotation_matrix([0, 1, 0], -np.pi / 2).transpose() - - - v, f = vertices.dot(N), faces - return v, f - -def as_mesh(scene_or_mesh): - if isinstance(scene_or_mesh, trimesh.Scene): - mesh = trimesh.util.concatenate([ - trimesh.Trimesh(vertices=m.vertices, faces=m.faces) - for m in scene_or_mesh.geometry.values()]) - else: - mesh = scene_or_mesh - return mesh -def process_one(shape_dir, cat): - pc_paths = glob.glob(os.path.join(shape_dir, "*.obj")) - pc_paths = sorted(pc_paths) - - xml_paths = [] #[re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths] - - gen_pcs = [] - for path in pc_paths: - sample_mesh = trimesh.load(path, force='mesh') - v, f = rotate(sample_mesh.vertices,sample_mesh.faces) - mesh = trimesh.Trimesh(v, f) - sample_pts = trimesh.sample.sample_surface(mesh, 2048)[0] - gen_pcs.append(sample_pts) - xml_paths.append(re.sub('.obj', '.xml', os.path.basename(path))) - - - - gen_pcs = np.stack(gen_pcs, axis=0) - write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths, cat=cat) - - -def process(args): - shape_names = [n for n in sorted(os.listdir(args.src)) if - os.path.isdir(os.path.join(args.src, n)) and not n.startswith('x')] - - all_shape_dir = [os.path.join(args.src, name) for name in shape_names] - - Parallel(n_jobs=10, verbose=2)(delayed(process_one)(path) for path in all_shape_dir) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--src", type=str) - parser.add_argument("--cat", type=str) - args = parser.parse_args() - - process_one(args.src, args.cat) - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/utils/xml_from_ply.py b/utils/xml_from_ply.py deleted file mode 100644 index 6a0aff1..0000000 --- a/utils/xml_from_ply.py +++ /dev/null @@ -1,54 +0,0 @@ -import sys - -sys.path.append('..') -import argparse -import os -import numpy as np -import trimesh -import glob -from joblib import Parallel, delayed -import re -from utils.mitsuba_renderer import write_to_xml_batch - - -def process_one(shape_dir): - pc_paths = glob.glob(os.path.join(shape_dir, "fake*.ply")) - pc_paths = sorted(pc_paths) - - xml_paths = [re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths] - - gen_pcs = [] - for path in pc_paths: - sample_pts = trimesh.load(path) - sample_pts = np.array(sample_pts.vertices) - gen_pcs.append(sample_pts) - - raw_pc = np.array(trimesh.load(os.path.join(shape_dir, "raw.ply")).vertices) - raw_pc = np.concatenate([raw_pc, np.tile(raw_pc[0:1], (gen_pcs[0].shape[0]-raw_pc.shape[0],1))]) - - gen_pcs.append(raw_pc) - gen_pcs = np.stack(gen_pcs, axis=0) - xml_paths.append('raw.xml') - - write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths) - - -def process(args): - shape_names = [n for n in sorted(os.listdir(args.src)) if - os.path.isdir(os.path.join(args.src, n)) and not n.startswith('x')] - - all_shape_dir = [os.path.join(args.src, name) for name in shape_names] - - Parallel(n_jobs=10, 
diff --git a/utils/xml_from_ply.py b/utils/xml_from_ply.py
deleted file mode 100644
index 6a0aff1..0000000
--- a/utils/xml_from_ply.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import sys
-
-sys.path.append('..')
-import argparse
-import os
-import numpy as np
-import trimesh
-import glob
-from joblib import Parallel, delayed
-import re
-from utils.mitsuba_renderer import write_to_xml_batch
-
-
-def process_one(shape_dir):
-    pc_paths = glob.glob(os.path.join(shape_dir, "fake*.ply"))
-    pc_paths = sorted(pc_paths)
-
-    xml_paths = [re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths]
-
-    gen_pcs = []
-    for path in pc_paths:
-        sample_pts = trimesh.load(path)
-        sample_pts = np.array(sample_pts.vertices)
-        gen_pcs.append(sample_pts)
-
-    raw_pc = np.array(trimesh.load(os.path.join(shape_dir, "raw.ply")).vertices)
-    raw_pc = np.concatenate([raw_pc, np.tile(raw_pc[0:1], (gen_pcs[0].shape[0] - raw_pc.shape[0], 1))])
-
-    gen_pcs.append(raw_pc)
-    gen_pcs = np.stack(gen_pcs, axis=0)
-    xml_paths.append('raw.xml')
-
-    write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths)
-
-
-def process(args):
-    shape_names = [n for n in sorted(os.listdir(args.src)) if
-                   os.path.isdir(os.path.join(args.src, n)) and not n.startswith('x')]
-
-    all_shape_dir = [os.path.join(args.src, name) for name in shape_names]
-
-    Parallel(n_jobs=10, verbose=2)(delayed(process_one)(path) for path in all_shape_dir)
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--src", type=str)
-    args = parser.parse_args()
-
-    process_one(args.src)
-
-
-if __name__ == '__main__':
-    main()
\ No newline at end of file
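One easy-to-miss detail in `process_one`: the partial input `raw.ply` generally has fewer points than the generated `fake*.ply` samples, so it is padded by repeating its first vertex until all clouds can be stacked into a single array. The same padding in isolation, with hypothetical toy sizes:

```python
import numpy as np

raw_pc = np.random.rand(100, 3)  # partial input with fewer points (toy size)
target_n = 2048                  # number of points in each generated sample

# Repeat the first vertex to pad raw_pc up to target_n points.
pad = np.tile(raw_pc[0:1], (target_n - raw_pc.shape[0], 1))
raw_pc = np.concatenate([raw_pc, pad])

assert raw_pc.shape == (target_n, 3)
```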