diff --git a/README.md b/README.md
index a7bbcfe..b8bc816 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
# Shape Generation and Completion Through Point-Voxel Diffusion
+
+
+
-[Project]() | [Paper]()
+[Project](https://alexzhou907.github.io/pvd) | [Paper](https://arxiv.org/abs/2104.03670)
-Implementation of
+Implementation of Shape Generation and Completion Through Point-Voxel Diffusion
-## Pretrained Models
-
-Pretrained models can be accessed [here](https://www.dropbox.com/s/a3xydf594fzaokl/cifar10_pretrained.rar?dl=0).
+[Linqi Zhou](https://alexzhou907.github.io), [Yilun Du](https://yilundu.github.io/), [Jiajun Wu](https://jiajunwu.com/)
## Requirements:
@@ -20,24 +21,56 @@ cudatoolkit==10.1
matplotlib==2.2.5
tqdm==4.32.1
open3d==0.9.0
+trimesh==3.7.12
+scipy==1.5.1
```
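+
+These pinned packages can be installed into a clean environment, for example (a minimal sketch; the environment name is arbitrary, Python 3.6 is assumed to match the `emd_cuda` extension filename in the copy step below, and PyTorch/cudatoolkit should follow the pinned versions above):
+
+```bash
+$ conda create -n pvd python=3.6 -y
+$ conda activate pvd
+$ pip install matplotlib==2.2.5 tqdm==4.32.1 open3d==0.9.0 trimesh==3.7.12 scipy==1.5.1
+```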
+
+Install PyTorchEMD by running:
+```
+cd metrics/PyTorchEMD
+python setup.py install
+cp build/**/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
+```
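+
+To verify the build, a quick import check can be run from inside `metrics/PyTorchEMD` after the copy step above (a minimal sanity check; it only confirms that the compiled `emd_cuda` module loads):
+
+```bash
+$ python -c "import torch, emd_cuda; print('emd_cuda loaded')"
+```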
+
The code was tested on Ubuntu with a Titan RTX.
+## Data
-## Training on CIFAR-10:
+For generation, we use the ShapeNet point clouds, which can be downloaded [here](https://github.com/stevenygd/PointFlow).
+
+For completion, we use the ShapeNet renderings provided by [GenRe](https://github.com/xiumingzhang/GenRe-ShapeHD).
+We provide the script `convert_cam_params.py` to preprocess this data.
+
+For training the model on shape completion, we need camera parameters for each view,
+which are not directly available. To obtain them, simply run
+```bash
+$ python convert_cam_params.py --dataroot DATA_DIR --mitsuba_xml_root XML_DIR
+```
+which will create a `..._cam_params.npz` file for each view in the corresponding data folder.
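+
+The saved parameters can be inspected with NumPy, for example (a minimal sketch; the `extr`/`intr` keys match what `convert_cam_params.py` writes, while the file path here is hypothetical and depends on your data layout):
+
+```python
+import numpy as np
+
+# hypothetical path: the actual prefix comes from the corresponding view's filename
+params = np.load('GenReData/some_shape/view_0000_cam_params.npz')
+extr = params['extr']  # 4x4 world-to-camera extrinsic matrix
+intr = params['intr']  # 3x3 pinhole intrinsic matrix
+print(extr.shape, intr.shape)
+```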
+
+## Pretrained models
+Pretrained models can be downloaded [here](https://drive.google.com/drive/folders/1Q7aSaTr6lqmo8qx80nIm1j28mOHAHGiM?usp=sharing).
+
+## Training:
```bash
-$ python train_cifar.py
+$ python train_generation.py --category car|chair|airplane
```
Please refer to the Python file for optimal training parameters.
+## Testing:
+
+```bash
+$ python train_generation.py --category car|chair|airplane --model MODEL_PATH
+```
+
## Results
Some generative results are as follows.
-
-
+
+
diff --git a/assets/gen_comp.gif b/assets/gen_comp.gif
new file mode 100644
index 0000000..a169382
Binary files /dev/null and b/assets/gen_comp.gif differ
diff --git a/assets/mm_partnet.gif b/assets/mm_partnet.gif
new file mode 100644
index 0000000..28defd9
Binary files /dev/null and b/assets/mm_partnet.gif differ
diff --git a/assets/mm_redwood.gif b/assets/mm_redwood.gif
new file mode 100644
index 0000000..cb786e0
Binary files /dev/null and b/assets/mm_redwood.gif differ
diff --git a/assets/mm_shapenet.gif b/assets/mm_shapenet.gif
new file mode 100644
index 0000000..a9231d2
Binary files /dev/null and b/assets/mm_shapenet.gif differ
diff --git a/assets/pvd_teaser.gif b/assets/pvd_teaser.gif
new file mode 100644
index 0000000..37c03ad
Binary files /dev/null and b/assets/pvd_teaser.gif differ
diff --git a/convert_cam_params.py b/convert_cam_params.py
new file mode 100644
index 0000000..36abb2b
--- /dev/null
+++ b/convert_cam_params.py
@@ -0,0 +1,123 @@
+# Convert GenRe Mitsuba XML camera specs into per-view extrinsic/intrinsic matrices saved as ..._cam_params.npz files.
+from glob import glob
+import re
+import argparse
+import numpy as np
+from pathlib import Path
+import os
+
+def raw_camparam_from_xml(path, pose="lookAt"):
+    import xml.etree.ElementTree as ET
+    tree = ET.parse(path)
+    elm = tree.find("./sensor/transform/" + pose)
+    camparam = elm.attrib
+    origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',')
+    target = np.fromstring(camparam['target'], dtype=np.float32, sep=',')
+    up = np.fromstring(camparam['up'], dtype=np.float32, sep=',')
+    height = int(
+        tree.find("./sensor/film/integer[@name='height']").attrib['value'])
+    width = int(
+        tree.find("./sensor/film/integer[@name='width']").attrib['value'])
+
+    camparam = dict()
+    camparam['origin'] = origin
+    camparam['up'] = up
+    camparam['target'] = target
+    camparam['height'] = height
+    camparam['width'] = width
+    return camparam
+
+def get_cam_pos(origin, target, up):
+    # build an orthonormal camera frame from the Mitsuba lookAt parameters
+    inward = origin - target
+    right = np.cross(up, inward)
+    up = np.cross(inward, right)
+    rx = np.cross(up, inward)
+    ry = np.array(up)
+    rz = np.array(inward)
+    rx /= np.linalg.norm(rx)
+    ry /= np.linalg.norm(ry)
+    rz /= np.linalg.norm(rz)
+
+    rot = np.stack([
+        rx,
+        ry,
+        -rz
+    ], axis=0)
+
+    aff = np.concatenate([
+        np.eye(3), -origin[:,None]
+    ], axis=1)
+
+    # extrinsics: rotation applied after translating the world origin to the camera center
+    ext = np.matmul(rot, aff)
+
+    result = np.concatenate(
+        [ext, np.array([[0,0,0,1]])], axis=0
+    )
+
+    return result
+
+
+
+def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
+    depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png')))
+    cam_ext = ['_'.join(re.sub(dataroot.strip('/'), camera_param_dir.strip('/'), f).split('_')[:-1])+'.xml' for f in depths]
+
+    for i, (f, pth) in enumerate(zip(cam_ext, depths)):
+        if not os.path.exists(f):
+            continue
+        params = raw_camparam_from_xml(f)
+        origin, target, up, width, height = params['origin'], params['target'], params['up'],\
+            params['width'], params['height']
+
+        ext_matrix = get_cam_pos(origin, target, up)
+
+        # intrinsics: fixed 50mm focal length on a 36mm x 24mm sensor diagonal at 480x480 resolution
+        diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5
+        focal_length = 0.05
+        res = [480, 480]
+        h_relative = (res[1] / res[0])
+        sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2))
+        pix_size = sensor_width / res[0]
+
+        K = np.array([
+            [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
+            [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
+            [0, 0, 1]
+        ])
+
+        np.savez(pth.split('depth.png')[0] + 'cam_params.npz', extr=ext_matrix, intr=K)
+
+
+def main(opt):
+    dataroot_dir = Path(opt.dataroot)
+
+    # collect leaf data directories, skipping the directory that holds the Mitsuba XML files
+    leaf_subdirs = []
+
+    for dirpath, dirnames, filenames in os.walk(dataroot_dir):
+        if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
+            leaf_subdirs.append(dirpath)
+
+    for k, dir_ in enumerate(leaf_subdirs):
+        print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_))
+
+        convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)
+
+
+
+
+
+if __name__ == '__main__':
+    args = argparse.ArgumentParser()
+    args.add_argument('--dataroot', type=str, default='GenReData/')
+    args.add_argument('--mitsuba_xml_root', type=str, default='GenReData/genre-xml_v2')
+
+    opt = args.parse_args()
+
+    main(opt)
diff --git a/datasets/partnet.py b/datasets/partnet.py
index 72417a0..f074663 100644
--- a/datasets/partnet.py
+++ b/datasets/partnet.py
@@ -5,9 +5,7 @@ import os
import json
import random
import trimesh
-import csv
from plyfile import PlyData, PlyElement
-from glob import glob
def project_pc_to_image(points, resolution=64):
"""project point clouds into 2D image
@@ -181,33 +179,3 @@ class GANdatasetPartNet(Dataset):
-
-if __name__ == '__main__':
- data_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetPointCloud'
- data_raw_root = '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc'
- pc_dataroot = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k'
-
- sn_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2'
- classes = 'car'
- npoints = 2048
- # from datasets.shapenet_data_pc import ShapeNet15kPointClouds
- # pc_ds = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- # categories=[classes], split='train',
- # tr_sample_size=npoints,
- # te_sample_size=npoints,
- # scale=1.,
- # normalize_per_shape=False,
- # normalize_std_per_axis=False,
- # random_subsample=True)
-
- train_ds = GANdatasetPartNet('test', pc_dataroot, data_raw_root, classes, npoints, np.array([0,0,0]),
- np.array([1, 1, 1]))
-
- d1 = train_ds[0]
- real = d1['real']
- raw = d1['raw']
- m, s = d1['m'], d1['s']
- x = (torch.cat([raw, real], dim=-1) * s + m).transpose(0,1)
-
- write_ply(x.numpy(), 'x.ply')
- pass
diff --git a/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py
deleted file mode 100644
index cd76baf..0000000
--- a/shape_completion/12_res32_partnet_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py
+++ /dev/null
@@ -1,825 +0,0 @@
-import torch.multiprocessing as mp
-import torch.nn as nn
-import torch.optim as optim
-import torch.utils.data
-
-import argparse
-from model.unet import get_model
-from torch.distributions import Normal
-
-from utils.file_utils import *
-from utils.visualize import *
-from model.pvcnn_completion import PVCNN2Base
-import torch.distributed as dist
-from datasets.partnet import GANdatasetPartNet
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-'''
-some utils
-'''
-def rotation_matrix(axis, theta):
- """
- Return the rotation matrix associated with counterclockwise rotation about
- the given axis by theta radians.
- """
- axis = np.asarray(axis)
- axis = axis / np.sqrt(np.dot(axis, axis))
- a = np.cos(theta / 2.0)
- b, c, d = -axis * np.sin(theta / 2.0)
- aa, bb, cc, dd = a * a, b * b, c * c, d * d
- bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
- return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
- [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
- [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
-
-def rotate(vertices, faces):
- '''
- vertices: [numpoints, 3]
- '''
- M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
- N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
- K = rotation_matrix([0, 0, 1], np.pi).transpose()
-
- v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]]
- return v, f
-
-def norm(v, f):
- v = (v - v.min())/(v.max() - v.min()) - 0.5
-
- return v, f
-
-def getGradNorm(net):
- pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
- gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
- return pNorm, gradNorm
-
-
-def weights_init(m):
- """
- xavier initialization
- """
- classname = m.__class__.__name__
- if classname.find('Conv') != -1 and m.weight is not None:
- torch.nn.init.xavier_normal_(m.weight)
- elif classname.find('BatchNorm') != -1:
- m.weight.data.normal_()
- m.bias.data.fill_(0)
-
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
- Args:
- repeat_noise_steps (int): Number of denoising timesteps in which the same noise
- is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
- assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category):
-
- train_ds = GANdatasetPartNet('train', data_root, category, npoints)
- return train_ds
-
-
-def get_dataloader(opt, train_dataset, test_dataset=None):
-
- if opt.distribution_type == 'multi':
- train_sampler = torch.utils.data.distributed.DistributedSampler(
- train_dataset,
- num_replicas=opt.world_size,
- rank=opt.rank
- )
- if test_dataset is not None:
- test_sampler = torch.utils.data.distributed.DistributedSampler(
- test_dataset,
- num_replicas=opt.world_size,
- rank=opt.rank
- )
- else:
- test_sampler = None
- else:
- train_sampler = None
- test_sampler = None
-
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
- shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
-
- if test_dataset is not None:
- test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- else:
- test_dataloader = None
-
- return train_dataloader, test_dataloader, train_sampler, test_sampler
-
-
-def train(gpu, opt, output_dir, noises_init):
-
- set_seed(opt)
- logger = setup_logging(output_dir)
- if opt.distribution_type == 'multi':
- should_diag = gpu==0
- else:
- should_diag = True
- if should_diag:
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- if opt.distribution_type == 'multi':
- if opt.dist_url == "env://" and opt.rank == -1:
- opt.rank = int(os.environ["RANK"])
-
- base_rank = opt.rank * opt.ngpus_per_node
- opt.rank = base_rank + gpu
- dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
- world_size=opt.world_size, rank=opt.rank)
-
- opt.bs = int(opt.bs / opt.ngpus_per_node)
- opt.workers = 0
-
- opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
- opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
- opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
-
-
- ''' data '''
- train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes)
- dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
-
-
- '''
- create networks
- '''
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.distribution_type == 'multi': # Multiple processes, single GPU per process
- def _transform_(m):
- return nn.parallel.DistributedDataParallel(
- m, device_ids=[gpu], output_device=gpu)
-
- torch.cuda.set_device(gpu)
- netE.cuda(gpu)
- netE.multi_gpu_wrapper(_transform_)
-
-
- elif opt.distribution_type == 'single':
- def _transform_(m):
- return nn.parallel.DataParallel(m)
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- elif gpu is not None:
- torch.cuda.set_device(gpu)
- netE = netE.cuda(gpu)
- else:
- raise ValueError('distribution_type = multi | single | None')
-
- if should_diag:
- logger.info(opt)
-
- optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))
-
- lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)
-
- if opt.netE != '':
- ckpt = torch.load(opt.netE)
- netE.load_state_dict(ckpt['model_state'])
- optimizer.load_state_dict(ckpt['optimizer_state'])
-
- if opt.netE != '':
- start_epoch = torch.load(opt.netE)['epoch'] + 1
- else:
- start_epoch = 0
-
-
- for epoch in range(start_epoch, opt.niter):
-
- if opt.distribution_type == 'multi':
- train_sampler.set_epoch(epoch)
-
- lr_scheduler.step(epoch)
-
- for i, data in enumerate(dataloader):
-
- x = data['real']
- sv_x = data['raw']
-
- sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
- noises_batch = noises_init[data['idx']]
-
- '''
- train diffusion
- '''
-
- if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
- sv_x = sv_x.cuda(gpu)
- noises_batch = noises_batch.cuda(gpu)
- elif opt.distribution_type == 'single':
- sv_x = sv_x.cuda()
- noises_batch = noises_batch.cuda()
-
- loss = netE.get_loss_iter(sv_x, noises_batch).mean()
-
- optimizer.zero_grad()
- loss.backward()
- netpNorm, netgradNorm = getGradNorm(netE)
- if opt.grad_clip is not None:
- torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)
-
- optimizer.step()
-
-
- if i % opt.print_freq == 0 and should_diag:
-
- logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
- 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
- .format(
- epoch, opt.niter, i, len(dataloader),loss.item(),
- netpNorm, netgradNorm,
- ))
-
- if (epoch + 1) % opt.diagIter == 0 and should_diag:
-
- logger.info('Diagnosis:')
-
- x_range = [x.min().item(), x.max().item()]
- kl_stats = netE.all_kl(sv_x)
- logger.info(' [{:>3d}/{:>3d}] '
- 'x_range: [{:>10.4f}, {:>10.4f}], '
- 'total_bpd_b: {:>10.4f}, '
- 'terms_bpd: {:>10.4f}, '
- 'prior_bpd_b: {:>10.4f} '
- 'mse_bt: {:>10.4f} '
- .format(
- epoch, opt.niter,
- *x_range,
- kl_stats['total_bpd_b'].item(),
- kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
- ))
-
-
-
- if (epoch + 1) % opt.vizIter == 0 and should_diag:
- logger.info('Generation: eval')
-
- netE.eval()
- m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0)
-
- with torch.no_grad():
-
- x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()
-
-
- gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
- gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
-
- logger.info(' [{:>3d}/{:>3d}] '
- 'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
- 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
- .format(
- epoch, opt.niter,
- *gen_eval_range, *gen_stats,
- ))
-
- export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
- (x_gen_eval*s+m).transpose(1, 2).numpy()*3)
-
- export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
- (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)
-
- export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
- (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)
-
-
- netE.train()
-
-
-
-
-
-
-
- if (epoch + 1) % opt.saveIter == 0:
-
- if should_diag:
-
-
- save_dict = {
- 'epoch': epoch,
- 'model_state': netE.state_dict(),
- 'optimizer_state': optimizer.state_dict()
- }
-
- torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
-
-
- if opt.distribution_type == 'multi':
- dist.barrier()
- map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
- netE.load_state_dict(
- torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
-
- dist.destroy_process_group()
-
-def main():
- opt = parse_args()
-
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
-
- ''' workaround '''
-
- train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes)
- noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints)
-
- if opt.dist_url == "env://" and opt.world_size == -1:
- opt.world_size = int(os.environ["WORLD_SIZE"])
-
- if opt.distribution_type == 'multi':
- opt.ngpus_per_node = torch.cuda.device_count()
- opt.world_size = opt.ngpus_per_node * opt.world_size
- mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
- else:
- train(opt.gpu, opt, output_dir, noises_init)
-
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
- parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/data_v0', help='input batch size')
- parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc',
- help='input batch size')
- parser.add_argument('--pc_dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
- help='input batch size')
- parser.add_argument('--classes', default='Chair')
-
- parser.add_argument('--bs', type=int, default=64, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- parser.add_argument('--svpoints', default=1024)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
- parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
- parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
- parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
- parser.add_argument('--grad_clip', type=float, default=None, help='weight decay for EBM')
- parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')
-
- parser.add_argument('--netE', default='', help="path to netE (to continue training)")
-
-
- '''distributed'''
- parser.add_argument('--world_size', default=1, type=int,
- help='Number of distributed nodes.')
- parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
- help='url used to set up distributed training')
- parser.add_argument('--dist_backend', default='nccl', type=str,
- help='distributed backend')
- parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
- help='Use multi-processing distributed training to launch '
- 'N processes per node, which has N GPUs. This is the '
- 'fastest way to use PyTorch for either single node or '
- 'multi node data parallel training')
- parser.add_argument('--rank', default=0, type=int,
- help='node rank for distributed training')
- parser.add_argument('--gpu', default=None, type=int,
- help='GPU id to use. None means using all available GPUs.')
-
- '''eval'''
- parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch')
- parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
- parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch')
- parser.add_argument('--print_freq', default=50, type=int,help='unit: iter')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
-
- opt = parser.parse_args()
-
- return opt
-
-if __name__ == '__main__':
- main()
diff --git a/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py
deleted file mode 100644
index f7a39fa..0000000
--- a/shape_completion/13_res32_partnet_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py
+++ /dev/null
@@ -1,825 +0,0 @@
-import torch.multiprocessing as mp
-import torch.nn as nn
-import torch.optim as optim
-import torch.utils.data
-
-import argparse
-from model.unet import get_model
-from torch.distributions import Normal
-
-from utils.file_utils import *
-from utils.visualize import *
-from model.pvcnn_completion import PVCNN2Base
-import torch.distributed as dist
-from datasets.partnet import GANdatasetPartNet
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-'''
-some utils
-'''
-def rotation_matrix(axis, theta):
- """
- Return the rotation matrix associated with counterclockwise rotation about
- the given axis by theta radians.
- """
- axis = np.asarray(axis)
- axis = axis / np.sqrt(np.dot(axis, axis))
- a = np.cos(theta / 2.0)
- b, c, d = -axis * np.sin(theta / 2.0)
- aa, bb, cc, dd = a * a, b * b, c * c, d * d
- bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
- return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
- [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
- [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
-
-def rotate(vertices, faces):
- '''
- vertices: [numpoints, 3]
- '''
- M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
- N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
- K = rotation_matrix([0, 0, 1], np.pi).transpose()
-
- v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]]
- return v, f
-
-def norm(v, f):
- v = (v - v.min())/(v.max() - v.min()) - 0.5
-
- return v, f
-
-def getGradNorm(net):
- pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
- gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
- return pNorm, gradNorm
-
-
-def weights_init(m):
- """
- xavier initialization
- """
- classname = m.__class__.__name__
- if classname.find('Conv') != -1 and m.weight is not None:
- torch.nn.init.xavier_normal_(m.weight)
- elif classname.find('BatchNorm') != -1:
- m.weight.data.normal_()
- m.bias.data.fill_(0)
-
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
- Args:
- repeat_noise_steps (int): Number of denoising timesteps in which the same noise
- is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
- assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category):
-
- train_ds = GANdatasetPartNet('train', data_root, category, npoints)
- return train_ds
-
-
-def get_dataloader(opt, train_dataset, test_dataset=None):
-
- if opt.distribution_type == 'multi':
- train_sampler = torch.utils.data.distributed.DistributedSampler(
- train_dataset,
- num_replicas=opt.world_size,
- rank=opt.rank
- )
- if test_dataset is not None:
- test_sampler = torch.utils.data.distributed.DistributedSampler(
- test_dataset,
- num_replicas=opt.world_size,
- rank=opt.rank
- )
- else:
- test_sampler = None
- else:
- train_sampler = None
- test_sampler = None
-
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
- shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
-
- if test_dataset is not None:
-        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
-                                                       shuffle=False, num_workers=int(opt.workers), drop_last=False)
- else:
- test_dataloader = None
-
- return train_dataloader, test_dataloader, train_sampler, test_sampler
-
-
-def train(gpu, opt, output_dir, noises_init):
-
- set_seed(opt)
- logger = setup_logging(output_dir)
- if opt.distribution_type == 'multi':
- should_diag = gpu==0
- else:
- should_diag = True
- if should_diag:
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- if opt.distribution_type == 'multi':
- if opt.dist_url == "env://" and opt.rank == -1:
- opt.rank = int(os.environ["RANK"])
-
- base_rank = opt.rank * opt.ngpus_per_node
- opt.rank = base_rank + gpu
- dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
- world_size=opt.world_size, rank=opt.rank)
-
- opt.bs = int(opt.bs / opt.ngpus_per_node)
- opt.workers = 0
-
- opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
- opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
- opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
-
-
- ''' data '''
- train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes)
- dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
-
-
- '''
- create networks
- '''
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.distribution_type == 'multi': # Multiple processes, single GPU per process
- def _transform_(m):
- return nn.parallel.DistributedDataParallel(
- m, device_ids=[gpu], output_device=gpu)
-
- torch.cuda.set_device(gpu)
- netE.cuda(gpu)
- netE.multi_gpu_wrapper(_transform_)
-
-
- elif opt.distribution_type == 'single':
- def _transform_(m):
- return nn.parallel.DataParallel(m)
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- elif gpu is not None:
- torch.cuda.set_device(gpu)
- netE = netE.cuda(gpu)
- else:
- raise ValueError('distribution_type = multi | single | None')
-
- if should_diag:
- logger.info(opt)
-
- optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))
-
- lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)
-
- if opt.netE != '':
- ckpt = torch.load(opt.netE)
- netE.load_state_dict(ckpt['model_state'])
- optimizer.load_state_dict(ckpt['optimizer_state'])
-
- if opt.netE != '':
- start_epoch = torch.load(opt.netE)['epoch'] + 1
- else:
- start_epoch = 0
-
-
- for epoch in range(start_epoch, opt.niter):
-
- if opt.distribution_type == 'multi':
- train_sampler.set_epoch(epoch)
-
- lr_scheduler.step(epoch)
-
- for i, data in enumerate(dataloader):
-
- x = data['real']
- sv_x = data['raw']
-
- sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
- noises_batch = noises_init[data['idx']]
-
- '''
- train diffusion
- '''
-
- if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
- sv_x = sv_x.cuda(gpu)
- noises_batch = noises_batch.cuda(gpu)
- elif opt.distribution_type == 'single':
- sv_x = sv_x.cuda()
- noises_batch = noises_batch.cuda()
-
- loss = netE.get_loss_iter(sv_x, noises_batch).mean()
-
- optimizer.zero_grad()
- loss.backward()
- netpNorm, netgradNorm = getGradNorm(netE)
- if opt.grad_clip is not None:
- torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)
-
- optimizer.step()
-
-
- if i % opt.print_freq == 0 and should_diag:
-
- logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
- 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
- .format(
- epoch, opt.niter, i, len(dataloader),loss.item(),
- netpNorm, netgradNorm,
- ))
-
- if (epoch + 1) % opt.diagIter == 0 and should_diag:
-
- logger.info('Diagnosis:')
-
- x_range = [x.min().item(), x.max().item()]
- kl_stats = netE.all_kl(sv_x)
- logger.info(' [{:>3d}/{:>3d}] '
- 'x_range: [{:>10.4f}, {:>10.4f}], '
- 'total_bpd_b: {:>10.4f}, '
- 'terms_bpd: {:>10.4f}, '
- 'prior_bpd_b: {:>10.4f} '
- 'mse_bt: {:>10.4f} '
- .format(
- epoch, opt.niter,
- *x_range,
- kl_stats['total_bpd_b'].item(),
- kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
- ))
-
-
-
- if (epoch + 1) % opt.vizIter == 0 and should_diag:
- logger.info('Generation: eval')
-
- netE.eval()
-            # normalization stats (mean, std); identity here since train_dataset.get_pc_stats(0) is not used
-            m, s = torch.tensor([[0, 0, 0]]).transpose(0, 1)[None], torch.tensor([[1, 1, 1]]).transpose(0, 1)[None]
-
- with torch.no_grad():
-
- x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()
-
-
- gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
- gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
-
- logger.info(' [{:>3d}/{:>3d}] '
- 'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
- 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
- .format(
- epoch, opt.niter,
- *gen_eval_range, *gen_stats,
- ))
-
- export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
- (x_gen_eval*s+m).transpose(1, 2).numpy()*3)
-
- export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
- (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)
-
- export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
- (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)
-
-
- netE.train()
-
-
-
-
-
-
-
- if (epoch + 1) % opt.saveIter == 0:
-
- if should_diag:
-
-
- save_dict = {
- 'epoch': epoch,
- 'model_state': netE.state_dict(),
- 'optimizer_state': optimizer.state_dict()
- }
-
- torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
-
-
- if opt.distribution_type == 'multi':
- dist.barrier()
- map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
- netE.load_state_dict(
- torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
-
- dist.destroy_process_group()
-
-def main():
- opt = parse_args()
-
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
-
- ''' workaround '''
-
- train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes)
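-    # fixed per-shape Gaussian noise for the points to be generated (npoints - svpoints); only used when the sampled timestep is 0 (see get_loss_iter)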
- noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints)
-
- if opt.dist_url == "env://" and opt.world_size == -1:
- opt.world_size = int(os.environ["WORLD_SIZE"])
-
- if opt.distribution_type == 'multi':
- opt.ngpus_per_node = torch.cuda.device_count()
- opt.world_size = opt.ngpus_per_node * opt.world_size
- mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
- else:
- train(opt.gpu, opt, output_dir, noises_init)
-
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
-    parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/', help='path to the PartNet data root')
-    parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc',
-                        help='path to the raw ShapeNet SDF point cloud data')
-    parser.add_argument('--pc_dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
-                        help='path to the ShapeNet point cloud data (ShapeNetCore.v2.PC15k)')
- parser.add_argument('--classes', default='Table')
-
- parser.add_argument('--bs', type=int, default=64, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- parser.add_argument('--svpoints', default=1024)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
- parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
- parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
- parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
-    parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm for clipping (None disables clipping)')
- parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')
-
- parser.add_argument('--netE', default='', help="path to netE (to continue training)")
-
-
- '''distributed'''
- parser.add_argument('--world_size', default=1, type=int,
- help='Number of distributed nodes.')
- parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
- help='url used to set up distributed training')
- parser.add_argument('--dist_backend', default='nccl', type=str,
- help='distributed backend')
- parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
- help='Use multi-processing distributed training to launch '
- 'N processes per node, which has N GPUs. This is the '
- 'fastest way to use PyTorch for either single node or '
- 'multi node data parallel training')
- parser.add_argument('--rank', default=0, type=int,
- help='node rank for distributed training')
- parser.add_argument('--gpu', default=None, type=int,
- help='GPU id to use. None means using all available GPUs.')
-
- '''eval'''
- parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch')
- parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
- parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch')
- parser.add_argument('--print_freq', default=50, type=int,help='unit: iter')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
-
- opt = parser.parse_args()
-
- return opt
-
-if __name__ == '__main__':
- main()
diff --git a/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py b/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py
deleted file mode 100644
index b8a23f5..0000000
--- a/shape_completion/14_res32_partnet_lamp_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do.py
+++ /dev/null
@@ -1,822 +0,0 @@
-import torch.multiprocessing as mp
-import torch.nn as nn
-import torch.optim as optim
-import torch.utils.data
-
-import argparse
-from model.unet import get_model
-from torch.distributions import Normal
-
-from utils.file_utils import *
-from utils.visualize import *
-from model.pvcnn_completion import PVCNN2Base
-import torch.distributed as dist
-from datasets.partnet import GANdatasetPartNet
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-'''
-some utils
-'''
-def rotation_matrix(axis, theta):
- """
- Return the rotation matrix associated with counterclockwise rotation about
- the given axis by theta radians.
- """
- axis = np.asarray(axis)
- axis = axis / np.sqrt(np.dot(axis, axis))
- a = np.cos(theta / 2.0)
- b, c, d = -axis * np.sin(theta / 2.0)
- aa, bb, cc, dd = a * a, b * b, c * c, d * d
- bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
- return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
- [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
- [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
-
-def rotate(vertices, faces):
- '''
- vertices: [numpoints, 3]
- '''
- M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
- N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
- K = rotation_matrix([0, 0, 1], np.pi).transpose()
-
- v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]]
- return v, f
-
-def norm(v, f):
- v = (v - v.min())/(v.max() - v.min()) - 0.5
-
- return v, f
-
-def getGradNorm(net):
- pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
- gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
- return pNorm, gradNorm
-
-
-def weights_init(m):
- """
- xavier initialization
- """
- classname = m.__class__.__name__
- if classname.find('Conv') != -1 and m.weight is not None:
- torch.nn.init.xavier_normal_(m.weight)
- elif classname.find('BatchNorm') != -1:
- m.weight.data.normal_()
- m.bias.data.fill_(0)
-
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
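-        # re-attach the observed partial points so they remain fixed throughout the reverse process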
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
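-        # start from the observed partial points concatenated with pure Gaussian noise for the missing part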
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
-        Args:
-            freq (int): interval (in timesteps) at which intermediate samples are appended to the
-                returned list, which starts with the initial noise and ends with the final sample.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
-                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
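-        # the pre-generated noise is kept only for entries with t == 0; all other entries draw fresh Gaussian noise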
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-def get_dataset(data_root, npoints, category):
-
- train_ds = GANdatasetPartNet('train', data_root, category, npoints)
- return train_ds
-
-
-def get_dataloader(opt, train_dataset, test_dataset=None):
-
- if opt.distribution_type == 'multi':
- train_sampler = torch.utils.data.distributed.DistributedSampler(
- train_dataset,
- num_replicas=opt.world_size,
- rank=opt.rank
- )
- if test_dataset is not None:
- test_sampler = torch.utils.data.distributed.DistributedSampler(
- test_dataset,
- num_replicas=opt.world_size,
- rank=opt.rank
- )
- else:
- test_sampler = None
- else:
- train_sampler = None
- test_sampler = None
-
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
- shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
-
- if test_dataset is not None:
-        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
-                                                       shuffle=False, num_workers=int(opt.workers), drop_last=False)
- else:
- test_dataloader = None
-
- return train_dataloader, test_dataloader, train_sampler, test_sampler
-
-
-def train(gpu, opt, output_dir, noises_init):
-
- set_seed(opt)
- logger = setup_logging(output_dir)
- if opt.distribution_type == 'multi':
- should_diag = gpu==0
- else:
- should_diag = True
- if should_diag:
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- if opt.distribution_type == 'multi':
- if opt.dist_url == "env://" and opt.rank == -1:
- opt.rank = int(os.environ["RANK"])
-
- base_rank = opt.rank * opt.ngpus_per_node
- opt.rank = base_rank + gpu
- dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
- world_size=opt.world_size, rank=opt.rank)
-
- opt.bs = int(opt.bs / opt.ngpus_per_node)
- opt.workers = 0
-
- opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
- opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
- opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
-
-
- ''' data '''
- train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
- dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
-
-
- '''
- create networks
- '''
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.distribution_type == 'multi': # Multiple processes, single GPU per process
- def _transform_(m):
- return nn.parallel.DistributedDataParallel(
- m, device_ids=[gpu], output_device=gpu)
-
- torch.cuda.set_device(gpu)
- netE.cuda(gpu)
- netE.multi_gpu_wrapper(_transform_)
-
-
- elif opt.distribution_type == 'single':
- def _transform_(m):
- return nn.parallel.DataParallel(m)
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- elif gpu is not None:
- torch.cuda.set_device(gpu)
- netE = netE.cuda(gpu)
- else:
- raise ValueError('distribution_type = multi | single | None')
-
- if should_diag:
- logger.info(opt)
-
- optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))
-
- lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)
-
- if opt.netE != '':
- ckpt = torch.load(opt.netE)
- netE.load_state_dict(ckpt['model_state'])
- optimizer.load_state_dict(ckpt['optimizer_state'])
-
- if opt.netE != '':
- start_epoch = torch.load(opt.netE)['epoch'] + 1
- else:
- start_epoch = 0
-
-
- for epoch in range(start_epoch, opt.niter):
-
- if opt.distribution_type == 'multi':
- train_sampler.set_epoch(epoch)
-
- lr_scheduler.step(epoch)
-
- for i, data in enumerate(dataloader):
-
- x = data['real']
- sv_x = data['raw']
-
- sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
- noises_batch = noises_init[data['idx']]
-
- '''
- train diffusion
- '''
-
- if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
- sv_x = sv_x.cuda(gpu)
- noises_batch = noises_batch.cuda(gpu)
- elif opt.distribution_type == 'single':
- sv_x = sv_x.cuda()
- noises_batch = noises_batch.cuda()
-
- loss = netE.get_loss_iter(sv_x, noises_batch).mean()
-
- optimizer.zero_grad()
- loss.backward()
- netpNorm, netgradNorm = getGradNorm(netE)
- if opt.grad_clip is not None:
- torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)
-
- optimizer.step()
-
-
- if i % opt.print_freq == 0 and should_diag:
-
- logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
- 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
- .format(
- epoch, opt.niter, i, len(dataloader),loss.item(),
- netpNorm, netgradNorm,
- ))
-
- if (epoch + 1) % opt.diagIter == 0 and should_diag:
-
- logger.info('Diagnosis:')
-
- x_range = [x.min().item(), x.max().item()]
- kl_stats = netE.all_kl(sv_x)
- logger.info(' [{:>3d}/{:>3d}] '
- 'x_range: [{:>10.4f}, {:>10.4f}], '
- 'total_bpd_b: {:>10.4f}, '
- 'terms_bpd: {:>10.4f}, '
- 'prior_bpd_b: {:>10.4f} '
- 'mse_bt: {:>10.4f} '
- .format(
- epoch, opt.niter,
- *x_range,
- kl_stats['total_bpd_b'].item(),
- kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
- ))
-
-
-
- if (epoch + 1) % opt.vizIter == 0 and should_diag:
- logger.info('Generation: eval')
-
- netE.eval()
-            # normalization stats (mean, std); identity here since train_dataset.get_pc_stats(0) is not used
-            m, s = torch.tensor([[0, 0, 0]]).transpose(0, 1)[None], torch.tensor([[1, 1, 1]]).transpose(0, 1)[None]
-
- with torch.no_grad():
-
- x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()
-
-
- gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
- gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
-
- logger.info(' [{:>3d}/{:>3d}] '
- 'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
- 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
- .format(
- epoch, opt.niter,
- *gen_eval_range, *gen_stats,
- ))
-
- export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
- (x_gen_eval*s+m).transpose(1, 2).numpy()*3)
-
- export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
- (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)
-
- export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
- (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)
-
-
- netE.train()
-
-
-
-
-
-
-
- if (epoch + 1) % opt.saveIter == 0:
-
- if should_diag:
-
-
- save_dict = {
- 'epoch': epoch,
- 'model_state': netE.state_dict(),
- 'optimizer_state': optimizer.state_dict()
- }
-
- torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
-
-
- if opt.distribution_type == 'multi':
- dist.barrier()
- map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
- netE.load_state_dict(
- torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
-
- dist.destroy_process_group()
-
-def main():
- opt = parse_args()
-
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
-
- ''' workaround '''
-
- train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
- noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints)
-
- if opt.dist_url == "env://" and opt.world_size == -1:
- opt.world_size = int(os.environ["WORLD_SIZE"])
-
- if opt.distribution_type == 'multi':
- opt.ngpus_per_node = torch.cuda.device_count()
- opt.world_size = opt.ngpus_per_node * opt.world_size
- mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
- else:
- train(opt.gpu, opt, output_dir, noises_init)
-
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
-    parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet/', help='path to the PartNet data root')
-
- parser.add_argument('--classes', default='Table')
-
- parser.add_argument('--bs', type=int, default=64, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- parser.add_argument('--svpoints', default=1024)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
- parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
- parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
- parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
-    parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm for clipping (None disables clipping)')
- parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')
-
- parser.add_argument('--netE', default='', help="path to netE (to continue training)")
-
-
- '''distributed'''
- parser.add_argument('--world_size', default=1, type=int,
- help='Number of distributed nodes.')
- parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
- help='url used to set up distributed training')
- parser.add_argument('--dist_backend', default='nccl', type=str,
- help='distributed backend')
- parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
- help='Use multi-processing distributed training to launch '
- 'N processes per node, which has N GPUs. This is the '
- 'fastest way to use PyTorch for either single node or '
- 'multi node data parallel training')
- parser.add_argument('--rank', default=0, type=int,
- help='node rank for distributed training')
- parser.add_argument('--gpu', default=None, type=int,
- help='GPU id to use. None means using all available GPUs.')
-
- '''eval'''
- parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch')
- parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
- parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch')
- parser.add_argument('--print_freq', default=50, type=int,help='unit: iter')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
-
- opt = parser.parse_args()
-
- return opt
-
-if __name__ == '__main__':
- main()
diff --git a/shape_completion/__init__.py b/shape_completion/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/shape_completion/control_gen_chair.py b/shape_completion/control_gen_chair.py
deleted file mode 100644
index 96a1387..0000000
--- a/shape_completion/control_gen_chair.py
+++ /dev/null
@@ -1,660 +0,0 @@
-
-from pprint import pprint
-from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
-
-import torch.nn as nn
-import torch.utils.data
-
-import argparse
-from torch.distributions import Normal
-from utils.visualize import *
-from utils.file_utils import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_completion import PVCNN2Base
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.shapenet_data_sv import *
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
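-        # the generated points are concatenated back onto the fixed partial input at every step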
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
-        Args:
-            freq (int): interval (in timesteps) at which intermediate samples are appended to the
-                returned list, which starts with the initial noise and ends with the final sample.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
-
-
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
-                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
- def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb, denoise_fn, noise_fn=torch.randn):
-
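-        # Sketch of the interpolation performed below:
-        #   1. diffuse both sets of free points to step t-1 via q_sample
-        #   2. linearly blend the two noisy latents with weight lamb
-        #   3. form the conditioning partial shape from a (1 - lamb) fraction of x0_sv's points and the rest from x1_sv
-        #   4. run the reverse chain from step t-1 back to 0 to obtain the interpolated completion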
- assert t >= 1
-
- t_vec = torch.empty(x0_part.shape[0], dtype=torch.int64, device=x0_part.device).fill_(t-1)
- encoding0 = self.q_sample(x0_part, t_vec)
- encoding1 = self.q_sample(x1_part, t_vec)
-
- enc = encoding0 * (1-lamb) + (lamb) * encoding1
-
- img_t = torch.cat([torch.cat([x0_sv[:,:,:int(self.sv_points*(1-lamb))], x1_sv[:,:,:(self.sv_points - int(self.sv_points*(1-lamb)))]], dim=-1), enc], dim=-1)
-
- for k in reversed(range(0,t)):
- t_ = torch.empty(img_t.shape[0], dtype=torch.int64, device=img_t.device).fill_(k)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=False, return_pred_xstart=False).detach()
-
-
- return img_t
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
-        B, D, N = data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
- def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb):
-
- return self.diffusion.interpolate(x0_part, x1_part, x0_sv, x1_sv, t, lamb, self._denoise)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
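The `warm*` schedules above keep beta at `b_end` except for a linear ramp over the first fraction of steps. A small sketch comparing the linear and warm0.2 schedules with the default beta range, recomputed inline so it runs standalone:

```python
import numpy as np

b_start, b_end, T = 1e-4, 2e-2, 1000

linear = np.linspace(b_start, b_end, T)

warm02 = b_end * np.ones(T)
warmup = int(0.2 * T)
warm02[:warmup] = np.linspace(b_start, b_end, warmup)

print(linear[0], linear[-1])          # 0.0001 0.02
print(warm02[0], warm02[warmup - 1])  # ramps from 0.0001 up to 0.02 over the first 200 steps
print(warm02[warmup])                 # stays at 0.02 afterwards
```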
-
-
-
-#############################################################################
-def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-def get_mvr_dataset(pc_dataroot, views_root, npoints,category, get_image=True):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
- cache=os.path.join(pc_dataroot, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std, get_image=get_image,
- )
- return te_dataset
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-        if i != 3:   # hard-coded: only the fourth test batch is used (the teaser example)
-            continue
- gt_all = data['test_points']
- x_all = data['sv_points']
- img_all = data['image']
-
- for v in range(20):
-
- recons = []
- svs = []
- for p in [0,1]:
- x = x_all[:,p].transpose(1, 2).contiguous()
- img = img_all[:,p]
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- recons.append(recon)
- svs.append(x[:, :opt.svpoints,:])
-
- for l, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
-
- # im = np.fliplr(np.flipud(d[-1]))
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,l), 'mode_%03d'%v, 'depth_%d'%p),
- (torch.stack(d[:-1], dim=0)* s[0] + m[0]).numpy(), cat='chair')
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,l), 'mode_%03d'%v, 'depth_%d'%p),
- (torch.stack(d[:-1], dim=0)* s[0] + m[0]).numpy())
- plt.imsave(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v, 'depth_%d.png' % p),
- d[-1].permute(1, 2, 0), cmap='gray')
-
- x0_part = recons[0].transpose(1, 2).contiguous()[:,:,opt.svpoints:].cuda()
- x1_part = recons[1].transpose(1, 2).contiguous()[:,:,opt.svpoints:].cuda()
- x0_sv = svs[0].transpose(1,2).cuda()
- x1_sv = svs[1].transpose(1,2).cuda()
-
- interres = []
- for lamb in np.linspace(0.1, 0.9, 5):
- res = netE.interpolate(x0_part, x1_part, x0_sv, x1_sv, 1000, lamb)
-
- res = torch.cat([x0_sv, x1_sv, res[:,:,opt.svpoints:]], dim=-1).detach().cpu().transpose(1,2).contiguous()
- interres.append(res)
- for l, d in enumerate(torch.stack(interres, dim=1)):
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, l), 'mode_%03d' % v),
- (d* s[0] + m[0]).numpy(), cat='chair')
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v),
- (d * s[0] + m[0]).numpy())
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
- for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
-
- opt.netE = ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
-    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point cloud dataset')
-    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
-                        help='path to the preprocessed single-view renderings')
- parser.add_argument('--classes', default=['chair'])
-
- parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--eval_recon_mvr', default=False)
- parser.add_argument('--generate_multimodal', default=False)
- parser.add_argument('--eval_saved', default=False)
- parser.add_argument('--eval_redwood', default=True)
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- parser.add_argument('--svpoints', default=200)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
-
- # constrain function
- parser.add_argument('--constrain_eps', default=0.2)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
- parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)")
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
-
- parser.add_argument('--manualSeed', default=None, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
- set_seed(opt)
- main(opt)
diff --git a/shape_completion/teaser_chair.py b/shape_completion/teaser_chair.py
deleted file mode 100644
index 5d022e5..0000000
--- a/shape_completion/teaser_chair.py
+++ /dev/null
@@ -1,706 +0,0 @@
-
-
-import torch.nn as nn
-import torch.utils.data
-
-import argparse
-from torch.distributions import Normal
-from utils.visualize import *
-from utils.file_utils import *
-from utils.mitsuba_renderer2 import write_to_xml_batch, write_to_xml
-from model.pvcnn_completion import PVCNN2Base
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.shapenet_data_sv import *
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
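As a sanity check, the closed-form expression used by `normal_kl` matches `torch.distributions.kl_divergence` for Gaussians parameterized by mean and log-variance. A toy verification with the formula rewritten inline:

```python
import torch
from torch.distributions import Normal, kl_divergence

mean1, logvar1 = torch.tensor(0.3), torch.tensor(-1.0)
mean2, logvar2 = torch.tensor(0.0), torch.tensor(0.5)

ours = 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
              + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
ref = kl_divergence(Normal(mean1, (0.5 * logvar1).exp()),
                    Normal(mean2, (0.5 * logvar2).exp()))
print(torch.allclose(ours, ref))  # True
```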
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
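A minimal standalone example of the gather-and-reshape pattern in `_extract`: one coefficient is selected per batch element and reshaped so it broadcasts against a B x C x N point cloud (toy shapes, not the repository's API):

```python
import torch

coefs = torch.linspace(1.0, 0.2, 5)   # per-timestep coefficients, shape [T]
t = torch.tensor([0, 4, 2])           # one timestep per batch element
x = torch.randn(3, 3, 2048)           # B x C x N point clouds

out = torch.gather(coefs, 0, t)                            # shape [3]
out = torch.reshape(out, [3] + (len(x.shape) - 1) * [1])   # shape [3, 1, 1]
print((out * x).shape)                                     # torch.Size([3, 3, 2048])
```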
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
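`q_sample` draws x_t directly from the closed form x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps, so diffusing to any timestep is a single call. A toy sketch, assuming the default linear schedule, of how the x_0 coefficient decays while the overall scale of x_t stays near 1:

```python
import numpy as np
import torch

betas = np.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.from_numpy(np.cumprod(1. - betas)).float()

x0 = torch.randn(4, 3, 2048)
for t in [0, 250, 500, 999]:
    a_bar = alphas_cumprod[t]
    x_t = a_bar.sqrt() * x0 + (1. - a_bar).sqrt() * torch.randn_like(x0)
    # the x0 coefficient shrinks toward 0 while std(x_t) remains ~1
    print(t, round(a_bar.sqrt().item(), 4), round(x_t.std().item(), 4))
```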
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
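`_predict_xstart_from_eps` is the algebraic inverse of `q_sample` given the same eps: sqrt(1/alpha_bar) x_t - sqrt(1/alpha_bar - 1) eps recovers x_0 exactly. A quick numeric check with a toy alpha_bar:

```python
import torch

alpha_bar = torch.tensor(0.37)
x0 = torch.randn(4, 3, 16)
eps = torch.randn_like(x0)

x_t = alpha_bar.sqrt() * x0 + (1. - alpha_bar).sqrt() * eps                   # q_sample
x0_rec = (1. / alpha_bar).sqrt() * x_t - (1. / alpha_bar - 1.).sqrt() * eps   # _predict_xstart_from_eps
print(torch.allclose(x0, x0_rec, atol=1e-5))  # True
```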
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
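For completion, `p_sample_loop` keeps the `sv_points` conditioning points fixed and only denoises the remaining points; the returned tensor contains both. A short shape-bookkeeping sketch using the default configuration (npoints=2048, svpoints=200) from the argument parser:

```python
B, C = 8, 3
npoints, sv_points = 2048, 200

partial_x_shape = (B, C, sv_points)        # fixed single-view conditioning points
free_shape = (B, C, npoints - sv_points)   # the `shape` argument: points to denoise
out_shape = (B, C, npoints)                # returned tensor; completion = out[:, :, sv_points:]
print(partial_x_shape, free_shape, out_shape)
```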
-
-
- def p_sample_loop_trajectory2(self, partial_x, denoise_fn, shape, device, num_save,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
- Args:
-            num_save (int): target number of intermediate point clouds to keep
-                while running the reverse chain.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- scale = np.exp(np.log(1/total_steps)/num_save)
- save_step = total_steps
-
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- imgs = [img_t.detach().cpu()]
- for t in reversed(range(0,total_steps)):
-            if (t+1) == save_step and t > 0 and len(imgs)
-        torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
- Args:
-            freq (int): an intermediate point cloud is stored every `freq` reverse steps,
-                in addition to the initial noise and the first and last steps.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
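The `mse` branch of `p_losses` is the standard eps-prediction objective, applied only to the free points while the observed conditioning points are concatenated back in for the network. A hedged sketch of one training step wired this way; `diffusion` and `net` are assumed stand-ins for the GaussianDiffusion above and a denoiser returning a B x 3 x N tensor, not the repository's exact API:

```python
import torch

def toy_train_step(diffusion, net, optimizer, x0, sv_points):
    # x0: B x 3 x N point clouds, first sv_points columns are the observed partial shape
    B = x0.shape[0]
    t = torch.randint(0, diffusion.num_timesteps, (B,), device=x0.device)
    noise = torch.randn_like(x0[:, :, sv_points:])
    x_t = diffusion.q_sample(x0[:, :, sv_points:], t, noise)       # diffuse the free points only
    model_in = torch.cat([x0[:, :, :sv_points], x_t], dim=-1)      # re-attach conditioning points
    eps_pred = net(model_in, t)[:, :, sv_points:]
    loss = ((noise - eps_pred) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```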
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
-                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
-        B, D, N = data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-
-#############################################################################
-def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- radius=3, elev=0, azim=0, img_size=512, focal_length=1000,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
- cache=os.path.join(pc_dataroot, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-
-def evaluate_recon_mvr(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- img_all = data['image']
-
-
- B,V,N,C = x_all.shape
- gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
-
- x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- img = img_all.reshape(B * V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- images.append(img)
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean().item() for key, val in results.items()})
- logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- img_all = data['image']
-
-        Path(os.path.join(save_dir, 'x_%03d' % i)).mkdir(exist_ok=True, parents=True)
-        Path(os.path.join(save_dir, 'x_ply_%03d' % i)).mkdir(exist_ok=True, parents=True)
-
- for v in range(5):
- x = x_all.transpose(1, 2).contiguous()
- img = img_all
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
-
- im = np.fliplr(np.flipud(d[-1]))
- plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
- for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
-
- opt.netE = ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
- if opt.eval_recon_mvr:
- # Evaluate generation
- evaluate_recon_mvr(opt, netE, outf_syn, logger)
-
- if opt.generate_multimodal:
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
-    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point cloud dataset')
-    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
-                        help='path to the preprocessed single-view renderings')
- parser.add_argument('--classes', default=['car'])
-
- parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--eval_recon_mvr', default=True)
- parser.add_argument('--generate_multimodal', default=False)
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- parser.add_argument('--svpoints', default=200)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
-
- # constrain function
- parser.add_argument('--constrain_eps', default=0.2)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
- parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/3_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-03-08-40', help="path to netE (to continue training)")
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
- set_seed(opt)
-
- main(opt)
diff --git a/shape_completion/test_chair.py b/shape_completion/test_chair.py
deleted file mode 100644
index dda8e3c..0000000
--- a/shape_completion/test_chair.py
+++ /dev/null
@@ -1,753 +0,0 @@
-
-from pprint import pprint
-from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
-
-import torch.nn as nn
-import torch.utils.data
-
-import argparse
-from torch.distributions import Normal
-from utils.visualize import *
-from utils.file_utils import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_completion import PVCNN2Base
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.shapenet_data_sv import *
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
- Args:
-            freq (int): an intermediate point cloud is stored every `freq` reverse steps,
-                in addition to the initial noise and the first and last steps.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
-                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
-        B, D, N = data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-
-#############################################################################
-def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
- cache=os.path.join(pc_dataroot, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-
-def evaluate_recon_mvr(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- # img_all = data['image']
-
-
- B,V,N,C = x_all.shape
- gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
-
- x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- # img = img_all.reshape(B * V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- # img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- # images.append(img)
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- # images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean().item() for key, val in results.items()})
- logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
- del ref_pcs, masked, results
-
-def evaluate_saved(opt, netE, save_dir, logger):
- ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
-
- gt_pth = ours_base + '/recon_gt.pth'
- ours_pth = ours_base + '/ours_results.pth'
- gt = torch.load(gt_pth).permute(1,0,2,3)
- ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)
-
- all_res = {}
- for i, (gt_, ours_) in enumerate(zip(gt, ours)):
- results = compute_all_metrics(gt_, ours_, opt.batch_size)
-
- for key, val in results.items():
- if i == 0:
- all_res[key] = val
- else:
- all_res[key] += val
- pprint(results)
- for key, val in all_res.items():
- all_res[key] = val / gt.shape[0]
-
- pprint({key: val.mean().item() for key, val in all_res.items()})
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- img_all = data['image']
-
- for v in range(6):
- x = x_all.transpose(1, 2).contiguous()
- img = img_all
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
-
- im = np.fliplr(np.flipud(d[-1]))
- plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair')
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
-
-def redwood_demo(opt, netE, save_dir, logger):
- import open3d as o3d
- pth = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc_partial.ply"
- pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc.ply"
-
- points = np.asarray(o3d.io.read_point_cloud(pth).points)
-
- gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points)
-
- np.save('gt.npy', gt_points)
-
- test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
-
- m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float()
-
- x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float()
- x = (x-m)/s
-
-    x = x[None].transpose(1, 2).cuda()  # add batch dimension: (N, 3) -> (1, 3, N)
-
- res = []
- for k in range(20):
- recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
- recon = recon.transpose(1, 2).contiguous()
- recon = recon * s+ m
- res.append(recon)
- res = torch.cat(res, dim=0)
-
- write_to_xml_batch(os.path.join(save_dir, 'xml'),
- (res).numpy(), cat='chair')
-
- export_to_pc_batch(os.path.join(save_dir, 'ply'),
- (res).numpy())
-
- torch.save(res, os.path.join(save_dir, 'redwood_demo.pth'))
-
- pcwrite(os.path.join(save_dir, 'ply', 'gt.ply'),
- gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)])
- write_to_xml_batch(os.path.join(save_dir, 'xml_gt'),
- gt_points[None], cat='chair')
-
- exit()
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
- for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
-
- opt.netE = ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
- if opt.eval_recon_mvr:
- # Evaluate generation
- evaluate_recon_mvr(opt, netE, outf_syn, logger)
-
- if opt.generate_multimodal:
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
- if opt.eval_saved:
- evaluate_saved(opt, netE, outf_syn, logger)
-
- if opt.eval_redwood:
- redwood_demo(opt, netE, outf_syn, logger)
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
-    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to ShapeNet point cloud data')
-    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
-                        help='path to preprocessed ShapeNet renderings (GenRe)')
- parser.add_argument('--classes', default=['chair'])
-
- parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--eval_recon_mvr', default=False)
- parser.add_argument('--generate_multimodal', default=False)
- parser.add_argument('--eval_saved', default=False)
- parser.add_argument('--eval_redwood', default=True)
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- parser.add_argument('--svpoints', default=200)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
-
- # constrain function
- parser.add_argument('--constrain_eps', default=0.2)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
- parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)")
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
-
- main(opt)
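
The deleted completion scripts above all share one convention: the first `sv_points` columns of every point tensor hold the observed partial shape, and `p_sample_loop` re-attaches them after each denoising step so that only the remaining `N - sv_points` points are generated. Below is a minimal, self-contained sketch of that convention; the `toy_denoiser` and the simplified reverse-step update are placeholders, not the repository's PVCNN2 model or its exact posterior mean.

```python
import torch

def toy_denoiser(x, t):
    # placeholder for the PVCNN2 noise predictor; returns a tensor shaped like x
    return torch.zeros_like(x)

B, C, N, sv_points, T = 2, 3, 2048, 200, 10
betas = torch.linspace(1e-4, 0.02, T)
partial = torch.randn(B, C, sv_points)                        # observed partial shape (kept fixed)
x = torch.cat([partial, torch.randn(B, C, N - sv_points)], dim=-1)

for t in reversed(range(T)):
    eps = toy_denoiser(x, torch.full((B,), t, dtype=torch.int64))
    # schematic reverse step applied to the generated points only
    free = x[:, :, sv_points:] - betas[t].sqrt() * eps[:, :, sv_points:]
    x = torch.cat([partial, free], dim=-1)                    # re-attach the partial shape

assert x.shape == (B, C, N)
```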
diff --git a/shape_completion/test_partnet_chair.py b/shape_completion/test_partnet_chair.py
deleted file mode 100644
index babb8a0..0000000
--- a/shape_completion/test_partnet_chair.py
+++ /dev/null
@@ -1,599 +0,0 @@
-
-from pprint import pprint
-from tqdm import tqdm
-import torch.nn as nn
-import torch.utils.data
-
-import argparse
-from torch.distributions import Normal
-from utils.visualize import *
-from utils.file_utils import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_completion import PVCNN2Base
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.partnet import GANdatasetPartNet
-import trimesh
-import csv
-import numpy as np
-import random
-from plyfile import PlyData, PlyElement
-
-
-def write_ply(points, filename, text=False):
- """ input: Nx3, write points to filename as PLY format. """
- points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
- vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
- el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
- with open(filename, mode='wb') as f:
- PlyData([el], text=text).write(f)
-
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
-    # Assumes data values lie in [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
-            raise NotImplementedError(self.model_mean_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
- Args:
-          freq (int): interval (in timesteps) between intermediate samples appended to the returned list.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
-                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
-        B, D, N = data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-
-#############################################################################
-def get_dataset(data_root, npoints, category):
-
- train_ds = GANdatasetPartNet('test', data_root, category, npoints)
- return train_ds
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['real']
- x_all = data['raw']
-
- for j in range(5):
- x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1)
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
-
-
-
- for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))):
-
- partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1)
- rec = d[1]
- rid = d[2]
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j),
- (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair')
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j),
- (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy())
-
- raw_id = rid.split('.')[0]
- save_sample_dir = os.path.join(save_dir, "{}".format(raw_id))
- Path(save_sample_dir).mkdir(parents=True, exist_ok=True)
- # save input partial shape
- if j == 0:
- save_path = os.path.join(save_sample_dir, "raw.ply")
- write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path)
- # save completed shape
- save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j))
- write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path)
-
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
- for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
-
- opt.netE = ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
-
- if opt.generate_multimodal:
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
-    parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet',
-                        help='path to PartNet data root')
-
- parser.add_argument('--classes', default='Chair')
-
- parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--eval_recon_mvr', default=False)
- parser.add_argument('--generate_multimodal', default=True)
- parser.add_argument('--eval_saved', default=False)
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- parser.add_argument('--svpoints', default=1024)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
-
- # constrain function
- parser.add_argument('--constrain_eps', default=0.2)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
- parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)")
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
-
- main(opt)
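
For reference, the `normal_kl` helper defined in these deleted test scripts is the closed-form KL divergence between two diagonal Gaussians. The following self-contained check is a sketch written for this note (not taken from the repository); it confirms the expression matches the textbook formula KL = 0.5 * (log(v2/v1) + (v1 + (m1 - m2)^2)/v2 - 1).

```python
import torch

def normal_kl(mean1, logvar1, mean2, logvar2):
    # same expression as in the deleted scripts above
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * torch.exp(-logvar2))

m1, v1 = torch.tensor(0.3), torch.tensor(0.5)
m2, v2 = torch.tensor(-0.1), torch.tensor(1.2)
kl_code = normal_kl(m1, v1.log(), m2, v2.log())
kl_ref = 0.5 * ((v2 / v1).log() + (v1 + (m1 - m2) ** 2) / v2 - 1.0)
assert torch.allclose(kl_code, kl_ref)
```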
diff --git a/shape_completion/test_partnet_table.py b/shape_completion/test_partnet_table.py
deleted file mode 100644
index 8efbde5..0000000
--- a/shape_completion/test_partnet_table.py
+++ /dev/null
@@ -1,599 +0,0 @@
-
-from pprint import pprint
-from tqdm import tqdm
-import torch.nn as nn
-import torch.utils.data
-
-import argparse
-from torch.distributions import Normal
-from utils.visualize import *
-from utils.file_utils import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_completion import PVCNN2Base
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.partnet import GANdatasetPartNet
-import trimesh
-import csv
-import numpy as np
-import random
-from plyfile import PlyData, PlyElement
-
-
-def write_ply(points, filename, text=False):
- """ input: Nx3, write points to filename as PLY format. """
- points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
- vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
- el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
- with open(filename, mode='wb') as f:
- PlyData([el], text=text).write(f)
-
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
-    # Assumes data values lie in [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
-            raise NotImplementedError(self.model_mean_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning intermediate images
- Useful for visualizing how denoised images evolve over time
- Args:
-          freq (int): interval (in timesteps) between intermediate samples appended to the returned list.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
-                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
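- # if caller-provided noise is given, redraw it for every entry with t != 0 so each shape still gets fresh Gaussian noise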
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-def get_betas(schedule_type, b_start, b_end, time_num):
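- # 'linear' ramps beta from b_start to b_end over all steps; the 'warmX' schedules hold beta at b_end
- # and only ramp linearly over the first X fraction of the timesteps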
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-
-#############################################################################
-def get_dataset(data_root, npoints, category):
-
- train_ds = GANdatasetPartNet('test', data_root, category, npoints)
- return train_ds
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['real']
- x_all = data['raw']
-
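- # draw several completion modes for the same partial input to show the multi-modality of the model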
- for j in range(5):
- x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1)
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
-
-
-
- for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))):
-
- partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1)
- rec = d[1]
- rid = d[2]
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j),
- (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair')
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j),
- (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy())
-
- raw_id = rid.split('.')[0]
- save_sample_dir = os.path.join(save_dir, "{}".format(raw_id))
- Path(save_sample_dir).mkdir(parents=True, exist_ok=True)
- # save input partial shape
- if j == 0:
- save_path = os.path.join(save_sample_dir, "raw.ply")
- write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path)
- # save completed shape
- save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j))
- write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path)
-
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
- for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
-
- opt.netE = ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
-
- if opt.generate_multimodal:
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
- parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet',
- help='path to the PartNet data root')
-
- parser.add_argument('--classes', default='Table')
-
- parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--eval_recon_mvr', default=False)
- parser.add_argument('--generate_multimodal', default=True)
- parser.add_argument('--eval_saved', default=False)
-
- parser.add_argument('--nc', type=int, default=3)
- parser.add_argument('--npoints', type=int, default=2048)
- parser.add_argument('--svpoints', type=int, default=1024)
- '''model'''
- parser.add_argument('--beta_start', type=float, default=0.0001)
- parser.add_argument('--beta_end', type=float, default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', type=int, default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', type=float, default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
-
- # constrain function
- parser.add_argument('--constrain_eps', default=0.2)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
- parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help='directory containing the model checkpoints to evaluate')
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
-
- main(opt)
diff --git a/shape_completion/test_plane.py b/shape_completion/test_plane.py
deleted file mode 100644
index 5981c3e..0000000
--- a/shape_completion/test_plane.py
+++ /dev/null
@@ -1,681 +0,0 @@
-
-from pprint import pprint
-from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
-
-import torch.nn as nn
-import torch.utils.data
-
-import argparse
-from torch.distributions import Normal
-
-from utils.file_utils import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_completion import PVCNN2Base
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.shapenet_data_sv import *
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
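- # the denoiser sees the full cloud (fixed partial points followed by the noisy missing part);
- # only its output for the missing columns (sv_points:) is used as the eps prediction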
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.model_mean_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
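- # re-attach the untouched partial points so the next reverse step is again conditioned on them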
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
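- # start the reverse chain from the partial input concatenated with pure Gaussian noise for the missing region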
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning the intermediate point clouds
- Useful for visualizing how the denoised shape evolves over time
- Args:
- freq (int): an intermediate point cloud is appended to the returned trajectory every freq denoising steps.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
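- # only the missing region is diffused; the observed partial points stay clean and act as conditioning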
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
- assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-
-#############################################################################
-def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- radius=3, elev=0, azim=0, img_size=512, focal_length=1000,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
- cache=os.path.join(pc_dataroot, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-
-def evaluate_recon_mvr(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- img_all = data['image']
-
-
- B,V,N,C = x_all.shape
- gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
-
- x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- img = img_all.reshape(B * V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
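- # m, s are the dataset normalization statistics; points are mapped back to the original scale before metrics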
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- images.append(img)
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean().item() for key, val in results.items()})
- logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- img_all = data['image']
-
- Path(os.path.join(save_dir, 'x_%03d' % i)).mkdir(exist_ok=True, parents=True)
- Path(os.path.join(save_dir, 'x_ply_%03d' % i)).mkdir(exist_ok=True, parents=True)
-
- for v in range(5):
- x = x_all.transpose(1, 2).contiguous()
- img = img_all
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
-
- im = np.fliplr(np.flipud(d[-1]))
- plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
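- # evaluate the saved checkpoints from the newest epoch to the oldest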
-
- with torch.no_grad():
- for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
-
- opt.netE = ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
- if opt.eval_recon_mvr:
- # Evaluate generation
- evaluate_recon_mvr(opt, netE, outf_syn, logger)
-
- if opt.generate_multimodal:
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
- parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point-cloud data root')
- parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
- help='path to the preprocessed single-view rendering data root')
- parser.add_argument('--classes', default=['airplane'])
-
- parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--eval_recon_mvr', default=True)
- parser.add_argument('--generate_multimodal', default=False)
-
- parser.add_argument('--nc', type=int, default=3)
- parser.add_argument('--npoints', type=int, default=2048)
- parser.add_argument('--svpoints', type=int, default=200)
- '''model'''
- parser.add_argument('--beta_start', type=float, default=0.0001)
- parser.add_argument('--beta_end', type=float, default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', type=int, default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', type=float, default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
-
- # constrain function
- parser.add_argument('--constrain_eps', default=0.2)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
- parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/airplane_ckpt/', help='directory containing the model checkpoints to evaluate')
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
- set_seed(opt)
-
- main(opt)
diff --git a/shape_completion/test_table.py b/shape_completion/test_table.py
deleted file mode 100644
index 8fb92be..0000000
--- a/shape_completion/test_table.py
+++ /dev/null
@@ -1,764 +0,0 @@
-
-from pprint import pprint
-from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
-
-import torch.nn as nn
-import torch.utils.data
-
-import argparse
-from torch.distributions import Normal
-from utils.visualize import *
-from utils.file_utils import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_completion import PVCNN2Base
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.shapenet_data_sv import *
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-
-class GaussianDiffusion:
- def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.sv_points = sv_points
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)[:,:,self.sv_points:]
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
-
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
- else:
- raise NotImplementedError(self.model_mean_type)
-
-
- assert model_mean.shape == x_recon.shape
- assert model_variance.shape == model_log_variance.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
-
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
-
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, partial_x, denoise_fn, shape, device,
- noise_fn=torch.randn, clip_denoised=True, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
-
- assert isinstance(shape, (tuple, list))
- img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
- for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False)
-
- assert img_t[:,:,self.sv_points:].shape == shape
- return img_t
-
- def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
- noise_fn=torch.randn,clip_denoised=True, keep_running=False):
- """
- Generate samples, returning the intermediate point clouds
- Useful for visualizing how the denoised shape evolves over time
- Args:
- freq (int): an intermediate point cloud is appended to the returned trajectory every freq denoising steps.
- """
- assert isinstance(shape, (tuple, list))
-
- total_steps = self.num_timesteps if not keep_running else len(self.betas)
-
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- imgs = [img_t]
- for t in reversed(range(0,total_steps)):
-
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- return_pred_xstart=False)
- if t % freq == 0 or t == total_steps-1:
- imgs.append(img_t)
-
- assert imgs[-1].shape == shape
- return imgs
-
- '''losses'''
-
- def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
- denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
-
- kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
- kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
-
- return (kl, pred_xstart) if return_pred_xstart else kl
-
- def p_losses(self, denoise_fn, data_start, t, noise=None):
- """
- Training loss calculation
- """
- B, D, N = data_start.shape
- assert t.shape == torch.Size([B])
-
- if noise is None:
- noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
-
- data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
-
- if self.loss_type == 'mse':
- # predict the noise instead of x_start. seems to be weighted naturally like SNR
- eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
-
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- elif self.loss_type == 'kl':
- losses = self._vb_terms_bpd(
- denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
- return_pred_xstart=False)
- else:
- raise NotImplementedError(self.loss_type)
-
- assert losses.shape == torch.Size([B])
- return losses
-
- '''debug'''
-
- def _prior_bpd(self, x_start):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
- t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
- mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
- assert kl_prior.shape == x_start.shape
- return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
-
- def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
-
- with torch.no_grad():
- B, T = x_start.shape[0], self.num_timesteps
-
- vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
- for t in reversed(range(T)):
-
- t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
- # Calculate VLB term at the current timestep
- data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
- new_vals_b, pred_xstart = self._vb_terms_bpd(
- denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
- clip_denoised=clip_denoised, return_pred_xstart=True)
- # MSE for progressive prediction loss
- assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
- new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
- assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
- # Insert the calculated term into the tensor of all terms
- mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
- vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
- mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
- assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
-
- prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
- total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
- assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
- total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
- return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
-
- self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
- clip_denoised=True,
- keep_running=False):
- return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- clip_denoised=clip_denoised,
- keep_running=keep_running)
-
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-
-
-#############################################################################
-def get_pc_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- return tr_dataset
-
-def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
- cache=os.path.join(pc_dataroot, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-
-def evaluate_recon_mvr(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- # img_all = data['image']
-
-
- B,V,N,C = x_all.shape
- gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
-
- x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- # img = img_all.reshape(B * V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- # img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- # images.append(img)
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- # images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean().item() for key, val in results.items()})
- logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
- del ref_pcs, masked, results
-
-def evaluate_saved(opt, netE, save_dir, logger):
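- # compare previously saved completions against the saved ground truth and average the metrics over views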
- ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
-
- gt_pth = ours_base + '/recon_gt.pth'
- ours_pth = ours_base + '/ours_results.pth'
- gt = torch.load(gt_pth).permute(1,0,2,3)
- ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)
-
- all_res = {}
- for i, (gt_, ours_) in enumerate(zip(gt, ours)):
- results = compute_all_metrics(gt_, ours_, opt.batch_size)
-
- for key, val in results.items():
- if i == 0:
- all_res[key] = val
- else:
- all_res[key] += val
- pprint(results)
- for key, val in all_res.items():
- all_res[key] = val / gt.shape[0]
-
- pprint({key: val.mean().item() for key, val in all_res.items()})
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- img_all = data['image']
-
- for v in range(6):
- x = x_all.transpose(1, 2).contiguous()
- img = img_all
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
-
- im = np.fliplr(np.flipud(d[-1]))
- plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair')
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
-
-def redwood_demo(opt, netE, save_dir, logger):
- import open3d as o3d
- pth = "/viscam/u/alexzhou907/01DATA/redwood/01605_sample_1.ply"
- pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/01605_pc_gt.ply"
-
- points = np.asarray(o3d.io.read_point_cloud(pth).points)
-
- gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points)
-
- np.save('gt.npy', gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)])
-
- write_to_xml_batch(os.path.join(save_dir, 'xml_gt'),
- gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)][None], cat='table')
-
- test_dataset = get_pc_dataset(opt.dataroot_pc, opt.dataroot_sv,
- opt.npoints, opt.classes)
-
- m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float()
-
- x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float()
-
- x = (x-m)/s
-
-
- x = x[None].transpose(1,2).cuda()
-
- res = []
- for k in range(20):
- recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda',
- clip_denoised=False).detach().cpu()
- recon = recon.transpose(1, 2).contiguous()
- recon = recon * s+ m
- res.append(recon)
- res = torch.cat(res, dim=0)
-
- write_to_xml_batch(os.path.join(save_dir, 'xml'),
- (res).numpy(), cat='table')
-
- export_to_pc_batch(os.path.join(save_dir, 'ply'),
- (res).numpy())
-
- torch.save(res, os.path.join(save_dir, 'redwood_demo.pth'))
-
- exit()
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
-        for ckpt in reversed(sorted(ckpts, key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))):
-
- opt.netE = ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
- if opt.eval_recon_mvr:
- # Evaluate generation
- evaluate_recon_mvr(opt, netE, outf_syn, logger)
-
- if opt.generate_multimodal:
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
- if opt.eval_saved:
- evaluate_saved(opt, netE, outf_syn, logger)
-
- if opt.eval_redwood:
- redwood_demo(opt, netE, outf_syn, logger)
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
-    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point cloud dataset')
-    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
-                        help='path to the preprocessed single-view renderings')
- parser.add_argument('--classes', default=['table'])
-
- parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--eval_recon_mvr', default=False)
- parser.add_argument('--generate_multimodal', default=False)
- parser.add_argument('--eval_saved', default=False)
- parser.add_argument('--eval_redwood', default=True)
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- parser.add_argument('--svpoints', default=200)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
-
- # constrain function
- parser.add_argument('--constrain_eps', default=0.2)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
-    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/9_res32_pc_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-12-16-14-09-50', help='directory containing model checkpoints to evaluate')
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
-
- main(opt)
diff --git a/shapenet/__init__.py b/shapenet/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/shapenet/test_car.py b/shapenet/test_car.py
deleted file mode 100644
index 7997128..0000000
--- a/shapenet/test_car.py
+++ /dev/null
@@ -1,905 +0,0 @@
-import torch
-from pprint import pprint
-from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
-from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
-
-import torch.nn as nn
-import torch.optim as optim
-import torch.utils.data
-
-import argparse
-from model.unet import get_model
-from torch.distributions import Normal
-
-from utils.file_utils import *
-from utils.visualize import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_generation import PVCNN2Base
-
-from tqdm import tqdm
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.shapenet_data_sv import *
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-class GaussianDiffusion:
- def __init__(self,betas, loss_type, model_mean_type, model_var_type):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
-
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
-
- if clip_denoised:
- x_recon = torch.clamp(x_recon, -.5, .5)
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape == data.shape
- assert model_variance.shape == model_log_variance.shape == data.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
- assert noise.shape == data.shape
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))
-
- sample = model_mean
- if use_var:
- sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- assert sample.shape == pred_xstart.shape
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, denoise_fn, shape, device,
- noise_fn=torch.randn, constrain_fn=lambda x, t:x,
- clip_denoised=True, max_timestep=None, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
- if max_timestep is None:
- final_time = self.num_timesteps
- else:
- final_time = max_timestep
-
- assert isinstance(shape, (tuple, list))
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
- img_t = constrain_fn(img_t, t)
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False).detach()
-
-
- assert img_t.shape == shape
- return img_t
-
- def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x):
-
- assert t >= 1
-
- t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1)
- encoding = self.q_sample(x0, t_vec)
-
- img_t = encoding
-
- for k in reversed(range(0,t)):
- img_t = constrain_fn(img_t, k)
- t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=False, return_pred_xstart=False, use_var=True).detach()
-
-
- return img_t
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)
-
- self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- assert out.shape == torch.Size([B, D, N])
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x,
- clip_denoised=False, max_timestep=None,
- keep_running=False):
- return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- constrain_fn=constrain_fn,
- clip_denoised=clip_denoised, max_timestep=max_timestep,
- keep_running=keep_running)
-
- def reconstruct(self, x0, t, constrain_fn=lambda x, t:x):
-
- return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-def get_constrain_function(ground_truth, mask, eps, num_steps=1):
-    '''
-    :param ground_truth: observed partial point cloud to constrain against
-    :param mask: binary mask marking the constrained (observed) points
-    :param eps: base constraint step size, annealed over timesteps
-    :param num_steps: number of constraint steps applied per call
-    :return: constrain_fn(x, t) that returns the constrained x
-    '''
- # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
- eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 ))
- def constrain_fn(x, t):
- eps_ = eps_all[t] if (t<1000) else 0
- for _ in range(num_steps):
- x = x - eps_ * ((x - ground_truth) * mask)
-
-
- return x
- return constrain_fn
-
-
-#############################################################################
-
-def get_dataset(dataroot, npoints,category,use_mask=False):
- tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True, use_mask = use_mask)
- te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
- categories=category, split='val',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- use_mask=use_mask
- )
- return tr_dataset, te_dataset
-
-def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- radius=3, elev=0, azim=0, img_size=512, focal_length=1000,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-
-def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
-
-
- return te_dataset
-def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
- cache=os.path.join(pc_dataroot, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-
-def evaluate_gen(opt, ref_pcs, logger):
-
- if ref_pcs is None:
- _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
-        for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Collecting reference point clouds'):
- x = data['test_points']
- m, s = data['mean'].float(), data['std'].float()
-
- ref.append(x*s + m)
-
- ref_pcs = torch.cat(ref, dim=0).contiguous()
-
- logger.info("Loading sample path: %s"
- % (opt.eval_path))
- sample_pcs = torch.load(opt.eval_path).contiguous()
-
- logger.info("Generation sample size:%s reference size: %s"
- % (sample_pcs.size(), ref_pcs.size()))
-
-
- # Compute metrics
- results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- results = {k: (v.cpu().detach().item()
- if not isinstance(v, float) else v) for k, v in results.items()}
-
- pprint(results)
- logger.info(results)
-
- jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
- pprint('JSD: {}'.format(jsd))
- logger.info('JSD: {}'.format(jsd))
-
-def evaluate_recon(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- randind = i%24
- gt_all = data['test_points'][:,randind:randind+1]
- x_all = data['sv_points'][:,randind:randind+1]
- mask_all= data['masks'][:,randind:randind+1]
- img_all = data['image'][:,randind:randind+1]
-
-
- B,V,N,C = x_all.shape
-
- x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous()
- mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous()
- img = img_all.reshape(B*V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
- # for t in [10]:
- # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- # opt.constrain_steps)).detach().cpu()
-
- recon = netE.gen_samples(x.shape, 'cuda',
- constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- opt.constrain_steps),
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
-
- # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
- # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
- #
- # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- #
- # k+=1
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- images.append(img)
-
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean() for key, val in results.items()})
- logger.info({key: val.mean() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
- #
- # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- #
- # results = {k: (v.cpu().detach().item()
- # if not isinstance(v, float) else v) for k, v in results.items()}
- # pprint(results)
- # logger.info(results)
-def evaluate_recon_mvr(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- mask_all= data['masks']
- img_all = data['image']
-
-
- B,V,N,C = x_all.shape
- gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
-
- # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
- # for t in [10]:
- # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- # opt.constrain_steps)).detach().cpu()
-
- cd_res = []
- recon_res = []
- for p in range(5):
- x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- img = img_all.reshape(B * V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x.shape, 'cuda',
- constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- opt.constrain_steps),
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- cd = (((recon - x)**2)*mask.transpose(1, 2).contiguous()).sum(dim=(1,2))
-
- cd_res.append(cd)
- recon_res.append(recon)
-
- cd_res = torch.stack(cd_res, dim=0)
- recon_res = torch.stack(recon_res, dim=0)
- _, argmin = torch.min(cd_res, 0)
- recon = recon_res[argmin,torch.arange(0,argmin.shape[0])]
-
- # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
- # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
- #
- # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- #
- # k+=1
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- images.append(img)
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean().item() for key, val in results.items()})
- logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
- #
- # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- #
- # results = {k: (v.cpu().detach().item()
- # if not isinstance(v, float) else v) for k, v in results.items()}
- # pprint(results)
- # logger.info(results)
-
-
-
-def generate(netE, opt, logger):
-
- _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes)
-
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- with torch.no_grad():
-
- samples = []
- ref = []
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'):
-
- x = data['test_points'].transpose(1,2)
- m, s = data['mean'].float(), data['std'].float()
-
- gen = netE.gen_samples(x.shape,
- 'cuda', clip_denoised=False).detach().cpu()
-
- gen = gen.transpose(1,2).contiguous()
- x = x.transpose(1,2).contiguous()
-
-
-
- gen = gen * s + m
- x = x * s + m
- samples.append(gen)
- ref.append(x)
-
- visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None,
- None, None)
-
- write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), default_color='b')
-
- samples = torch.cat(samples, dim=0)
- ref = torch.cat(ref, dim=0)
-
- torch.save(samples, opt.eval_path)
-
-
-
- return ref
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_svr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- mask_all= data['masks']
- img_all = data['image']
-
-
- # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
- # for t in [10]:
- # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- # opt.constrain_steps)).detach().cpu()
-
- Path(os.path.join(save_dir, 'x_%03d'%k)).mkdir(exist_ok=True, parents=True)
- Path(os.path.join(save_dir, 'x_ply_%03d' % k)).mkdir(exist_ok=True, parents=True)
-
- for v in range(5):
- x = x_all.transpose(1, 2).contiguous()
- mask = mask_all.transpose(1, 2).contiguous()
- img = img_all
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x.shape, 'cuda',
- constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- opt.constrain_steps),
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
-
- for d in zip(list(gt_all), list(recon), list(x), list(img)):
-
- im = np.fliplr(np.flipud(d[-1]))
- plt.imsave(os.path.join(save_dir, 'depth_%03d.png'%k), im, cmap='gray')
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k, 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k, 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
- k+=1
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
-        for ckpt in reversed(sorted(ckpts, key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))):
-            opt.netE = ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
- ref = None
- if opt.generate:
- epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1])
- opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch))
- Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
- ref=generate(netE, opt, logger)
- if opt.eval_gen:
- # Evaluate generation
- evaluate_gen(opt, ref, logger)
-
- if opt.eval_recon:
- # Evaluate generation
- evaluate_recon(opt, netE, outf_syn, logger)
-
- if opt.eval_recon_mvr:
- # Evaluate generation
- evaluate_recon_mvr(opt, netE, outf_syn, logger)
-
- if opt.generate_multimodal:
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
- exit()
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
-    parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point cloud dataset')
- parser.add_argument('--classes', default=['car'])
-
- parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--generate',default=True)
- parser.add_argument('--eval_gen', default=True)
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
-
- # constrain function
- parser.add_argument('--constrain_eps', default=0.2)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
-    parser.add_argument('--ckpt_dir', default='', required=True, help='directory containing model checkpoints to evaluate')
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
- set_seed(opt)
-
- main(opt)
-
- # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair
\ No newline at end of file
diff --git a/shapenet/test_chair.py b/shapenet/test_chair.py
deleted file mode 100644
index c88ec0b..0000000
--- a/shapenet/test_chair.py
+++ /dev/null
@@ -1,911 +0,0 @@
-import torch
-from pprint import pprint
-from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
-from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
-
-import torch.nn as nn
-import torch.optim as optim
-import torch.utils.data
-
-import argparse
-from model.unet import get_model
-from torch.distributions import Normal
-
-from utils.file_utils import *
-from utils.visualize import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_generation import PVCNN2Base
-
-from tqdm import tqdm
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.shapenet_data_sv import *
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- # Assumes data is integers [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-class GaussianDiffusion:
- def __init__(self,betas, loss_type, model_mean_type, model_var_type):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
-
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
-
- if clip_denoised:
- x_recon = torch.clamp(x_recon, -.5, .5)
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape == data.shape
- assert model_variance.shape == model_log_variance.shape == data.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
- assert noise.shape == data.shape
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))
-
- sample = model_mean
- if use_var:
- sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- assert sample.shape == pred_xstart.shape
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, denoise_fn, shape, device,
- noise_fn=torch.randn, constrain_fn=lambda x, t:x,
- clip_denoised=True, max_timestep=None, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
- if max_timestep is None:
- final_time = self.num_timesteps
- else:
- final_time = max_timestep
-
- assert isinstance(shape, (tuple, list))
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
- for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
- img_t = constrain_fn(img_t, t)
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False).detach()
-
-
- assert img_t.shape == shape
- return img_t
-
- def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x):
-
- assert t >= 1
-
- t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1)
- encoding = self.q_sample(x0, t_vec)
-
- img_t = encoding
-
- for k in reversed(range(0,t)):
- img_t = constrain_fn(img_t, k)
- t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=False, return_pred_xstart=False, use_var=True).detach()
-
-
- return img_t
-
- def interpolate(self, x0, x1, t, lamb, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x):
-
- assert t >= 1
-
- t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1)
- encoding0 = self.q_sample(x0, t_vec)
- encoding1 = self.q_sample(x1, t_vec)
-
- enc = encoding0 * lamb + (1-lamb) * encoding1
-
- img_t = enc
-
- for k in reversed(range(0,t)):
- img_t = constrain_fn(img_t, k)
- t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=False, return_pred_xstart=False, use_var=True).detach()
-
-
- return img_t
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)
-
- self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- assert out.shape == torch.Size([B, D, N])
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x,
- clip_denoised=False, max_timestep=None,
- keep_running=False):
- return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- constrain_fn=constrain_fn,
- clip_denoised=clip_denoised, max_timestep=max_timestep,
- keep_running=keep_running)
-
- def reconstruct(self, x0, t, constrain_fn=lambda x, t:x):
-
- return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)
-
- def interpolate(self, x0, x1, t, lamb, constrain_fn=lambda x, t:x):
-
- return self.diffusion.interpolate(x0, x1, t, lamb, self._denoise, constrain_fn=constrain_fn)
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-def get_constrain_function(ground_truth, mask, eps, num_steps=1):
-    '''
-    :param ground_truth: observed partial point cloud to constrain against
-    :param mask: binary mask marking the constrained (observed) points
-    :param eps: base constraint step size, annealed over timesteps
-    :param num_steps: number of constraint steps applied per call
-    :return: constrain_fn(x, t) that returns the constrained x
-    '''
- # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
- eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 ))
- def constrain_fn(x, t):
- eps_ = eps_all[t] if (t<1000) else 0
- for _ in range(num_steps):
- x = x - eps_ * ((x - ground_truth) * mask)
-
-
- return x
- return constrain_fn
-
-
-#############################################################################
-
-def get_dataset(dataroot, npoints,category,use_mask=False):
- tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True, use_mask = use_mask)
- te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
- categories=category, split='val',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- use_mask=use_mask
- )
- return tr_dataset, te_dataset
-
-def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
-
-
- return te_dataset
-
-def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
- cache=os.path.join(pc_dataroot, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-
-def evaluate_gen(opt, ref_pcs, logger):
-
- if ref_pcs is None:
- _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'):
- x = data['test_points']
- m, s = data['mean'].float(), data['std'].float()
-
- ref.append(x*s + m)
-
- ref_pcs = torch.cat(ref, dim=0).contiguous()
-
- logger.info("Loading sample path: %s"
- % (opt.eval_path))
- sample_pcs = torch.load(opt.eval_path).contiguous()
-
- logger.info("Generation sample size:%s reference size: %s"
- % (sample_pcs.size(), ref_pcs.size()))
-
-
- # Compute metrics
- # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- # results = {k: (v.cpu().detach().item()
- # if not isinstance(v, float) else v) for k, v in results.items()}
- #
- # pprint(results)
- # logger.info(results)
-
- jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
- pprint('JSD: {}'.format(jsd))
- logger.info('JSD: {}'.format(jsd))
-
-def evaluate_recon(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
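-        # the dataset exposes multiple renders per object; cycle through the views (24 here) deterministically with the batch index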
- randind = i%24
- gt_all = data['test_points'][:,randind:randind+1]
- x_all = data['sv_points'][:,randind:randind+1]
- mask_all= data['masks'][:,randind:randind+1]
- img_all = data['image'][:,randind:randind+1]
-
-
- B,V,N,C = x_all.shape
-
- gt = gt_all.reshape(B*V,N,C).transpose(1,2).contiguous()
- x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous()
- mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous()
- img = img_all.reshape(B*V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
- # for t in [10]:
- # t_vec = torch.empty(gt.shape[0], dtype=torch.int64, device='cuda').fill_(80)
- # recon = netE.diffusion.q_sample(gt.cuda(), t_vec).detach().cpu()
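-        # complete the partial view: run the full reverse diffusion from noise while the
-        # constraint pulls the observed (masked) points toward the single-view input at every step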
- recon = netE.gen_samples(x.shape, 'cuda',
- constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- opt.constrain_steps),
- clip_denoised=False).detach().cpu()
-
- # recon = recon.transpose(1, 2).contiguous()
- # x = x.transpose(1, 2).contiguous()
- # gt = gt.transpose(1, 2).contiguous()
- # write_to_xml_batch(os.path.join(save_dir, 'intermediate_%03d' % i),
- # (recon.detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy())
- # write_to_xml_batch(os.path.join(save_dir, 'x_%03d' % i),
- # (gt.detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy())
- # write_to_xml_batch(os.path.join(save_dir, 'noise_%03d' % i),
- # (torch.randn_like(gt).detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy())
- # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
- # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
- #
- # k+=1
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- images.append(img)
-
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean().item() for key, val in results.items()})
- logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
- #
- # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- #
- # results = {k: (v.cpu().detach().item()
- # if not isinstance(v, float) else v) for k, v in results.items()}
- # pprint(results)
- # logger.info(results)
-
-def evaluate_recon_mvr(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- mask_all= data['masks']
- # img_all = data['image']
-
-
- B,V,N,C = x_all.shape
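-        # broadcast the single ground-truth cloud across all V views so it lines up with the per-view completions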
- gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
-
- x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- # img = img_all.reshape(B * V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x.shape, 'cuda',
- constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- opt.constrain_steps),
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- # img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- # images.append(img)
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- # images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean().item() for key, val in results.items()})
- logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
- #
- # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- #
- # results = {k: (v.cpu().detach().item()
- # if not isinstance(v, float) else v) for k, v in results.items()}
- # pprint(results)
- # logger.info(results)
-
- del ref_pcs, masked
-
-
-def generate(netE, opt, logger):
-
- _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes)
-
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- with torch.no_grad():
-
- samples = []
- ref = []
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'):
-
- x = data['test_points'].transpose(1,2)
- m, s = data['mean'].float(), data['std'].float()
-
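-            # sample a batch of shapes with the same (B, 3, N) layout as the reference, then undo the dataset normalization below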
- gen = netE.gen_samples(x.shape,
- 'cuda', clip_denoised=False).detach().cpu()
-
- gen = gen.transpose(1,2).contiguous()
- x = x.transpose(1,2).contiguous()
-
-
-
- gen = gen * s + m
- x = x * s + m
- samples.append(gen)
- ref.append(x)
-
- visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None,
- None, None)
-
- write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), cat='chair')
-
- samples = torch.cat(samples, dim=0)
- ref = torch.cat(ref, dim=0)
-
- torch.save(samples, opt.eval_path)
-
-
-
- return ref
-
-
-
-def generate_multimodal(opt, netE, save_dir, logger):
- test_dataset = get_svr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- mask_all= data['masks']
- img_all = data['image']
-
-
- # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
- # for t in [10]:
- # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- # opt.constrain_steps)).detach().cpu()
-
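-        # draw 10 independent completions ("modes") of the same partial input to illustrate multimodality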
- for v in range(10):
- x = x_all.transpose(1, 2).contiguous()
- mask = mask_all.transpose(1, 2).contiguous()
- img = img_all
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x.shape, 'cuda',
- constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- opt.constrain_steps),
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
-
- im = np.fliplr(np.flipud(d[-1]))
- plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
- write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
- export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
-
-
-
-
-
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
- for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
- #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-02-45/epoch_1399.pth'
- #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/67_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-03-38/epoch_2799.pth'
- opt.netE = '/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do_best/2020-10-16-12-23-44/epoch_1799.pth'#ckpt
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
- ref = None
- if opt.generate:
- epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1])
- opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch))
- Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
- ref=generate(netE, opt, logger)
- if opt.eval_gen:
- # Evaluate generation
- evaluate_gen(opt, ref, logger)
-
- if opt.eval_recon:
- # Evaluate generation
- evaluate_recon(opt, netE, outf_syn, logger)
-
- if opt.eval_recon_mvr:
- # Evaluate generation
- evaluate_recon_mvr(opt, netE, outf_syn, logger)
-
- if opt.generate_multimodal:
-
- generate_multimodal(opt, netE, outf_syn, logger)
-
-
-
- exit()
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
- parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
- parser.add_argument('--classes', default=['chair'])
-
- parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--generate',default=True)
- parser.add_argument('--eval_gen', default=False)
- parser.add_argument('--eval_recon', default=False)
- parser.add_argument('--eval_recon_mvr', default=False)
- parser.add_argument('--generate_multimodal', default=False)
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- '''model'''
- parser.add_argument('--beta_start', default=0.0001)
- parser.add_argument('--beta_end', default=0.02)
- parser.add_argument('--schedule_type', default='linear')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
- # constrain function
- parser.add_argument('--constrain_eps', default=.051)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
- parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do_best/2020-10-16-12-23-44/', help="path to netE (to continue training)")
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair/2020-10-18-13-46-21/syn/epoch_1699_samples.pth')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
- set_seed(opt)
-
- main(opt)
-
- # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair/2020-10-18-13-46-21
\ No newline at end of file
diff --git a/shapenet/test_plane.py b/shapenet/test_plane.py
deleted file mode 100644
index 5fd3b1d..0000000
--- a/shapenet/test_plane.py
+++ /dev/null
@@ -1,925 +0,0 @@
-import torch
-import functools
-from pprint import pprint
-from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
-from metrics.evaluation_metrics import compute_all_metrics, EMD_CD, distChamfer
-
-import torch.nn as nn
-import torch.optim as optim
-import torch.utils.data
-
-import argparse
-from model.unet import get_model
-from torch.distributions import Normal
-
-from utils.file_utils import *
-from utils.visualize import *
-from utils.mitsuba_renderer import write_to_xml_batch
-from model.pvcnn_generation import PVCNN2Base
-
-from tqdm import tqdm
-
-from datasets.shapenet_data_pc import ShapeNet15kPointClouds
-from datasets.shapenet_data_sv import *
-'''
-models
-'''
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- KL divergence between normal distributions parameterized by mean and log-variance.
- """
- return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
- + (mean1 - mean2)**2 * torch.exp(-logvar2))
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
-    # Assumes data has been rescaled to lie in [0, 1]
- assert x.shape == means.shape == log_scales.shape
- px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
-
- centered_x = x - means
- inv_stdv = torch.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 0.5)
- cdf_plus = px0.cdf(plus_in)
- min_in = inv_stdv * (centered_x - .5)
- cdf_min = px0.cdf(min_in)
- log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
- log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
- cdf_delta = cdf_plus - cdf_min
-
- log_probs = torch.where(
- x < 0.001, log_cdf_plus,
- torch.where(x > 0.999, log_one_minus_cdf_min,
- torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
- assert log_probs.shape == x.shape
- return log_probs
-
-
-class GaussianDiffusion:
- def __init__(self,betas, loss_type, model_mean_type, model_var_type):
- self.loss_type = loss_type
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- assert isinstance(betas, np.ndarray)
- self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
- assert (betas > 0).all() and (betas <= 1).all()
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
-
- # initialize twice the actual length so we can keep running for eval
- # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
-
- alphas = 1. - betas
- alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
- alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
-
- self.betas = torch.from_numpy(betas).float()
- self.alphas_cumprod = alphas_cumprod.float()
- self.alphas_cumprod_prev = alphas_cumprod_prev.float()
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
- self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
- self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
- self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
- self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
-
- betas = torch.from_numpy(betas).float()
- alphas = torch.from_numpy(alphas).float()
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.posterior_variance = posterior_variance
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
- self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
-
- @staticmethod
- def _extract(a, t, x_shape):
- """
- Extract some coefficients at specified timesteps,
- then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
- """
- bs, = t.shape
- assert x_shape[0] == bs
- out = torch.gather(a, 0, t)
- assert out.shape == torch.Size([bs])
- return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
-
-
-
- def q_mean_variance(self, x_start, t):
- mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
- variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
- log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data (t == 0 means diffused for 1 step)
- """
- if noise is None:
- noise = torch.randn(x_start.shape, device=x_start.device)
- assert noise.shape == x_start.shape
- return (
- self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
- self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
- )
-
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
- self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
- )
- posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
- posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
- x_start.shape[0])
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-
- def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
-
- model_output = denoise_fn(data, t)
-
-
- if self.model_var_type in ['fixedsmall', 'fixedlarge']:
- # below: only log_variance is used in the KL computations
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
- 'fixedlarge': (self.betas.to(data.device),
- torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
- 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
- }[self.model_var_type]
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
- model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
- else:
- raise NotImplementedError(self.model_var_type)
-
- if self.model_mean_type == 'eps':
- x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
-
- if clip_denoised:
- x_recon = torch.clamp(x_recon, -.5, .5)
-
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
- else:
- raise NotImplementedError(self.loss_type)
-
-
- assert model_mean.shape == x_recon.shape == data.shape
- assert model_variance.shape == model_log_variance.shape == data.shape
- if return_pred_xstart:
- return model_mean, model_variance, model_log_variance, x_recon
- else:
- return model_mean, model_variance, model_log_variance
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
- )
-
- ''' samples '''
-
- def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
- """
- Sample from the model
- """
- model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
- return_pred_xstart=True)
- noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
- assert noise.shape == data.shape
- # no noise when t == 0
- nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))
-
- sample = model_mean
- if use_var:
- sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
- assert sample.shape == pred_xstart.shape
- return (sample, pred_xstart) if return_pred_xstart else sample
-
-
- def p_sample_loop(self, denoise_fn, shape, device,
- noise_fn=torch.randn, constrain_fn=lambda x, t:x,
- clip_denoised=True, max_timestep=None, keep_running=False):
- """
- Generate samples
- keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
-
- """
- if max_timestep is None:
- final_time = self.num_timesteps
- else:
- final_time = max_timestep
-
- assert isinstance(shape, (tuple, list))
- img_t = noise_fn(size=shape, dtype=torch.float, device=device)
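-        # reverse diffusion: start from Gaussian noise and denoise step by step, applying constrain_fn before every step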
- for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
- img_t = constrain_fn(img_t, t)
- t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
- clip_denoised=clip_denoised, return_pred_xstart=False).detach()
-
-
- assert img_t.shape == shape
- return img_t
-
- def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x):
-
- assert t >= 1
-
- t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1)
- encoding = self.q_sample(x0, t_vec)
-
- img_t = encoding
-
- for k in reversed(range(0,t)):
- img_t = constrain_fn(img_t, k)
- t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=False, return_pred_xstart=False, use_var=True).detach()
-
-
- return img_t
-
- def reconstruct2(self, x0, mask, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda forward, x, t:x):
- z = noise_fn(size=x0.shape, dtype=torch.float, device=x0.device)
-
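-        # refine the starting noise z: run the reverse chain, measure the masked error to x0,
-        # then backpropagate that error through each stored sampling step and take a gradient step on z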
- for _ in range(10):
- img_t = z
- outputs =[None for _ in range(len(self.betas))]
- for t in reversed(range(0, len(self.betas))):
-
- t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t)
- outputs[t] = img_t.detach().cpu().clone()
-
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=False, return_pred_xstart=False, use_var=True).detach()
-
-
- img_t = torch.autograd.Variable(img_t.data, requires_grad=True)
-
- dist = ((img_t - x0) ** 2 * mask).sum(dim=0).mean()
- grad = torch.autograd.grad(dist, [img_t])[0].detach()
-
- print('Dist', dist.detach().cpu().item())
-
- for t in (range(0, len(outputs))):
-
- t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t)
-
- x = outputs[t].to(x0).requires_grad_()
-
- y = self.p_sample(denoise_fn=denoise_fn, data=x, t=t_, noise_fn=noise_fn,
- clip_denoised=False, return_pred_xstart=False, use_var=True)
-
- grad = torch.autograd.grad(y, [x], grad_outputs=grad)[0]
-
-
- z = x.detach().to(x0) - 0.1 * grad.detach()
-
- img_t = z
- for t in reversed(range(0, len(self.betas))):
-
- t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t)
-
- img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
- clip_denoised=False, return_pred_xstart=False, use_var=True).detach()
-
- return img_t
-
-
-
-class PVCNN2(PVCNN2Base):
- sa_blocks = [
- ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
- ((64, 3, 16), (256, 0.2, 32, (64, 128))),
- ((128, 3, 8), (64, 0.4, 32, (128, 256))),
- (None, (16, 0.8, 32, (256, 256, 512))),
- ]
- fp_blocks = [
- ((256, 256), (256, 3, 8)),
- ((256, 256), (256, 3, 8)),
- ((256, 128), (128, 2, 16)),
- ((128, 128, 64), (64, 2, 32)),
- ]
-
- def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
- voxel_resolution_multiplier=1):
- super().__init__(
- num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
- dropout=dropout, extra_feature_channels=extra_feature_channels,
- width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
- )
-
-
-class Model(nn.Module):
- def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
- super(Model, self).__init__()
- self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)
-
- self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention,
- dropout=args.dropout, extra_feature_channels=0)
-
- def prior_kl(self, x0):
- return self.diffusion._prior_bpd(x0)
-
- def all_kl(self, x0, clip_denoised=True):
- total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
-
- return {
- 'total_bpd_b': total_bpd_b,
- 'terms_bpd': vals_bt,
- 'prior_bpd_b': prior_bpd_b,
- 'mse_bt':mse_bt
- }
-
-
- def _denoise(self, data, t):
- B, D,N= data.shape
- assert data.dtype == torch.float
- assert t.shape == torch.Size([B]) and t.dtype == torch.int64
-
- out = self.model(data, t)
-
- assert out.shape == torch.Size([B, D, N])
- return out
-
- def get_loss_iter(self, data, noises=None):
- B, D, N = data.shape
- t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
-
- if noises is not None:
- noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
-
- losses = self.diffusion.p_losses(
- denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
- assert losses.shape == t.shape == torch.Size([B])
- return losses
-
- def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x,
- clip_denoised=False, max_timestep=None,
- keep_running=False):
- return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn,
- constrain_fn=constrain_fn,
- clip_denoised=clip_denoised, max_timestep=max_timestep,
- keep_running=keep_running)
-
- def reconstruct(self, x0, t, constrain_fn=lambda x, t:x):
-
- return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)
-
- def reconstruct2(self, x0, mask, constrain_fn):
-
- return self.diffusion.reconstruct2(x0, mask, self._denoise, constrain_fn=constrain_fn)
-
- def train(self):
- self.model.train()
-
- def eval(self):
- self.model.eval()
-
- def multi_gpu_wrapper(self, f):
- self.model = f(self.model)
-
-
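-# beta schedules: 'linear' ramps from b_start to b_end over all steps; 'warm0.X' ramps over the first X fraction of steps, then holds b_end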
-def get_betas(schedule_type, b_start, b_end, time_num):
- if schedule_type == 'linear':
- betas = np.linspace(b_start, b_end, time_num)
- elif schedule_type == 'warm0.1':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.1)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.2':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.2)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- elif schedule_type == 'warm0.5':
-
- betas = b_end * np.ones(time_num, dtype=np.float64)
- warmup_time = int(time_num * 0.5)
- betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
- else:
- raise NotImplementedError(schedule_type)
- return betas
-
-def get_constrain_function(ground_truth, mask, eps, num_steps=1):
-    '''
-    Build a constraint for guided completion: at every reverse step the observed
-    (masked) points are pulled toward the partial ground truth.
-    :param ground_truth: partial point cloud providing the observed points
-    :param mask: binary mask selecting the observed points
-    :param eps: base step size of the constraint schedule
-    :param num_steps: constraint updates applied per timestep
-    :return: constrain_fn(x, t) returning the constrained x
-    '''
- # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
- eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 500)**2 ))
- def constrain_fn(x, t):
- eps_ = eps_all[t] if (t<500) else 0
- for _ in range(num_steps):
- x = x - eps_ * ((x - ground_truth) * mask)
-
-
- return x
-
- # mask_single = mask[0, :, 0]
- # num = mask_single.sum().int().item()
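-    # gradient-based alternative (unused): differentiates a forward pass and steps along the masked-error gradient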
- def constrain_fn2(forward, x, t):
-
- x = torch.autograd.Variable(x.data, requires_grad=True)
- y = forward(x)
-
-
-
- dist = ((y - ground_truth)**2 * mask).sum(dim=0).mean()
- grad = torch.autograd.grad(dist, [x], retain_graph=True)[0]
- x = x - eps * (grad)
-
- print('Dist', dist.detach().cpu().item())
-
- return x
- return constrain_fn
-
-
-
-#############################################################################
-
-def get_dataset(dataroot, npoints,category,use_mask=False):
- tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True, use_mask = use_mask)
- te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
- categories=category, split='val',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- use_mask=use_mask
- )
- return tr_dataset, te_dataset
-
-def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
- cache=os.path.join(mesh_root, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category):
- tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
- categories=category, split='train',
- tr_sample_size=npoints,
- te_sample_size=npoints,
- scale=1.,
- normalize_per_shape=False,
- normalize_std_per_axis=False,
- random_subsample=True)
- te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
- cache=os.path.join(pc_dataroot, '../cache'), split='val',
- categories=category,
- npoints=npoints, sv_samples=200,
- all_points_mean=tr_dataset.all_points_mean,
- all_points_std=tr_dataset.all_points_std,
- )
- return te_dataset
-
-
-def evaluate_gen(opt, ref_pcs, logger):
-
- if ref_pcs is None:
- _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'):
- x = data['test_points']
- m, s = data['mean'].float(), data['std'].float()
-
- ref.append(x*s + m)
-
- ref_pcs = torch.cat(ref, dim=0).contiguous()
-
- logger.info("Loading sample path: %s"
- % (opt.eval_path))
- sample_pcs = torch.load(opt.eval_path).contiguous()
-
- logger.info("Generation sample size:%s reference size: %s"
- % (sample_pcs.size(), ref_pcs.size()))
-
-
- # Compute metrics
- # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- # results = {k: (v.cpu().detach().item()
- # if not isinstance(v, float) else v) for k, v in results.items()}
- #
- # pprint(results)
- # logger.info(results)
-
- jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
- pprint('JSD: {}'.format(jsd))
- logger.info('JSD: {}'.format(jsd))
-
-
-def evaluate_recon(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- randind = i%24
- gt_all = data['test_points'][:,randind:randind+1]
- x_all = data['sv_points'][:,randind:randind+1]
- mask_all= data['masks'][:,randind:randind+1]
- img_all = data['image'][:,randind:randind+1]
-
-
- B,V,N,C = x_all.shape
-
- x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous()
- mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous()
- img = img_all.reshape(B*V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
- # for t in [10]:
- # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- # opt.constrain_steps)).detach().cpu()
-
- recon = netE.gen_samples(x.shape, 'cuda',
- constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- opt.constrain_steps),
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
-
- # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
- # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
- #
- # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- #
- # k+=1
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- images.append(img)
-
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
-    pprint({key: val.mean().item() for key, val in results.items()})
-    logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
- #
- # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- #
- # results = {k: (v.cpu().detach().item()
- # if not isinstance(v, float) else v) for k, v in results.items()}
- # pprint(results)
- # logger.info(results)
-
-
-
-def evaluate_recon_mvr(opt, netE, save_dir, logger):
- test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
- opt.npoints, opt.classes)
- # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
- ref = []
- samples = []
- images = []
- masked = []
- k = 0
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
-
- gt_all = data['test_points']
- x_all = data['sv_points']
- mask_all= data['masks']
- img_all = data['image']
-
-
- B,V,N,C = x_all.shape
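-        # broadcast the single ground-truth cloud across all V views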
- gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
-
- # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
- # for t in [10]:
- # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- # opt.constrain_steps)).detach().cpu()
-
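-        # draw 5 candidate completions per view and keep, for each view, the one with the lowest masked error to the partial input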
- cd_res = []
- recon_res = []
- for p in range(5):
- x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
- img = img_all.reshape(B * V, *img_all.shape[2:])
-
- m, s = data['mean'].float(), data['std'].float()
-
- recon = netE.gen_samples(x.shape, 'cuda',
- constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
- opt.constrain_steps),
- clip_denoised=False).detach().cpu()
-
- recon = recon.transpose(1, 2).contiguous()
- x = x.transpose(1, 2).contiguous()
-
- cd = (((recon - x)**2)*mask.transpose(1, 2).contiguous()).sum(dim=(1,2))
-
- cd_res.append(cd)
- recon_res.append(recon)
-
- cd_res = torch.stack(cd_res, dim=0)
- recon_res = torch.stack(recon_res, dim=0)
- _, argmin = torch.min(cd_res, 0)
- recon = recon_res[argmin,torch.arange(0,argmin.shape[0])]
-
- # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
- # write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- # visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
- #
- # export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
- #
- # k+=1
-
- x_adj = x.reshape(B,V,N,C)* s + m
- recon_adj = recon.reshape(B,V,N,C)* s + m
- img = img.reshape(B,V,*img.shape[1:])
-
- ref.append( gt_all * s + m)
- masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
- samples.append(recon_adj)
- images.append(img)
-
- ref_pcs = torch.cat(ref, dim=0)
- sample_pcs = torch.cat(samples, dim=0)
- images = torch.cat(images, dim=0)
- masked = torch.cat(masked, dim=0)
-
- B, V, N, C = ref_pcs.shape
-
-
- torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
- torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
- torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
- # Compute metrics
- results = EMD_CD(sample_pcs.reshape(B*V, N, C),
- ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
-
- results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
-
- pprint({key: val.mean().item() for key, val in results.items()})
- logger.info({key: val.mean().item() for key, val in results.items()})
-
- results['pc'] = sample_pcs
- torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
-
- #
- # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
- #
- # results = {k: (v.cpu().detach().item()
- # if not isinstance(v, float) else v) for k, v in results.items()}
- # pprint(results)
- # logger.info(results)
-
-
-
-def generate(netE, opt, logger):
-
- _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes)
-
- test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
- shuffle=False, num_workers=int(opt.workers), drop_last=False)
-
- with torch.no_grad():
-
- samples = []
- ref = []
-
- for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'):
-
- x = data['test_points'].transpose(1,2)
- m, s = data['mean'].float(), data['std'].float()
-
- gen = netE.gen_samples(x.shape,
- 'cuda', clip_denoised=False).detach().cpu()
-
- gen = gen.transpose(1,2).contiguous()
- x = x.transpose(1,2).contiguous()
-
-
-
- gen = gen * s + m
- x = x * s + m
- samples.append(gen)
- ref.append(x)
-
- # visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x_%03d.png'%i), gen[:64], None,
- # None, None)
- # export_to_pc_batch(os.path.join(str(Path(opt.eval_path).parent), 'ply_%03d'%i),
- # gen[:64].numpy())
- write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), cat='airplane')
-
- samples = torch.cat(samples, dim=0)
- ref = torch.cat(ref, dim=0)
-
- torch.save(samples, opt.eval_path)
-
-
-
- return ref
-
-def main(opt):
- exp_id = os.path.splitext(os.path.basename(__file__))[0]
- dir_id = os.path.dirname(__file__)
- output_dir = get_output_dir(dir_id, exp_id)
- copy_source(__file__, output_dir)
- logger = setup_logging(output_dir)
-
- outf_syn, = setup_output_subdirs(output_dir, 'syn')
-
- betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
- netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
-
- if opt.cuda:
- netE.cuda()
-
- def _transform_(m):
- return nn.parallel.DataParallel(m)
-
- netE = netE.cuda()
- netE.multi_gpu_wrapper(_transform_)
-
- # netE.eval()
-
- ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
-
- with torch.no_grad():
- for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
- #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-02-45/epoch_1399.pth'
- #'/viscam/u/alexzhou907/research/diffusion/shapenet/output/67_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-10-04-01-03-38/epoch_2799.pth'
- # opt.netE = '/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do_best/2020-10-13-13-33-53/epoch_2899.pth'#ckpt
- opt.netE = '/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do/2020-10-07-13-26-10/epoch_2299.pth'
- logger.info("Resume Path:%s" % opt.netE)
-
- resumed_param = torch.load(opt.netE)
- netE.load_state_dict(resumed_param['model_state'])
-
-
- ref = None
- if opt.generate:
- epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1])
- opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch))
- Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
- ref=generate(netE, opt, logger)
- if opt.eval_gen:
- # Evaluate generation
- evaluate_gen(opt, ref, logger)
-
- if opt.eval_recon:
- # Evaluate generation
- evaluate_recon(opt, netE, outf_syn, logger)
-
- if opt.eval_recon_mvr:
- # Evaluate generation
- evaluate_recon_mvr(opt, netE, outf_syn, logger)
-
-
-def parse_args():
-
- parser = argparse.ArgumentParser()
- parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
- parser.add_argument('--classes', default=['airplane'])
-
- parser.add_argument('--batch_size', type=int, default=20, help='input batch size')
- parser.add_argument('--workers', type=int, default=16, help='workers')
- parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
-
- parser.add_argument('--generate',default=False)
- parser.add_argument('--eval_gen', default=True)
- parser.add_argument('--eval_recon', default=False)
- parser.add_argument('--eval_recon_mvr', default=False)
-
- parser.add_argument('--nc', default=3)
- parser.add_argument('--npoints', default=2048)
- '''model'''
- parser.add_argument('--beta_start', default=0.00001)
- parser.add_argument('--beta_end', default=0.008)
- parser.add_argument('--schedule_type', default='warm0.1')
- parser.add_argument('--time_num', default=1000)
-
- #params
- parser.add_argument('--attention', default=True)
- parser.add_argument('--dropout', default=0.1)
- parser.add_argument('--embed_dim', type=int, default=64)
- parser.add_argument('--loss_type', default='mse')
- parser.add_argument('--model_mean_type', default='eps')
- parser.add_argument('--model_var_type', default='fixedsmall')
-
- # constrain function
- parser.add_argument('--constrain_eps', default=.1)
- parser.add_argument('--constrain_steps', type=int, default=1)
-
- parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do_best/2020-10-13-13-33-53', help="path to netE (to continue training)")
-
- '''eval'''
-
- parser.add_argument('--eval_path',
- default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test_plane/2020-10-18-13-49-20/syn/epoch_2499_samples.pth')
-
- parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
-
- parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
-
- opt = parser.parse_args()
-
- if torch.cuda.is_available():
- opt.cuda = True
- else:
- opt.cuda = False
-
- return opt
-if __name__ == '__main__':
- opt = parse_args()
- set_seed(opt)
-
- main(opt)
diff --git a/shape_completion/test_completion.py b/test_completion.py
similarity index 98%
rename from shape_completion/test_completion.py
rename to test_completion.py
index 9be1ed7..4bbd5a1 100644
--- a/shape_completion/test_completion.py
+++ b/test_completion.py
@@ -7,9 +7,7 @@ import torch.utils.data
import argparse
from torch.distributions import Normal
-from utils.visualize import *
from utils.file_utils import *
-from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
@@ -579,9 +577,8 @@ def main(opt):
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
- parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
- help='input batch size')
+ parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/')
+ parser.add_argument('--dataroot_sv', default='GenReData/')
parser.add_argument('--category', default='chair')
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
diff --git a/shapenet/test_generation.py b/test_generation.py
similarity index 98%
rename from shapenet/test_generation.py
rename to test_generation.py
index 979e0d0..9659ba0 100644
--- a/shapenet/test_generation.py
+++ b/test_generation.py
@@ -4,16 +4,13 @@ from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
-import torch.optim as optim
import torch.utils.data
import argparse
-from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
-from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_generation import PVCNN2Base
from tqdm import tqdm
@@ -271,6 +268,8 @@ class PVCNN2(PVCNN2Base):
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
+
+
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
@@ -534,8 +533,8 @@ def main(opt):
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
- parser.add_argument('--category', default='car')
+ parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/')
+ parser.add_argument('--category', default='chair')
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
@@ -585,5 +584,3 @@ if __name__ == '__main__':
set_seed(opt)
main(opt)
-
- # results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair
\ No newline at end of file
diff --git a/shape_completion/train_completion.py b/train_completion.py
similarity index 98%
rename from shape_completion/train_completion.py
rename to train_completion.py
index 6c3aae5..8d81818 100644
--- a/shape_completion/train_completion.py
+++ b/train_completion.py
@@ -4,7 +4,6 @@ import torch.optim as optim
import torch.utils.data
import argparse
-from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
@@ -559,7 +558,7 @@ def train(gpu, opt, output_dir, noises_init):
''' data '''
- train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes)
+ train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.category)
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
@@ -609,10 +608,6 @@ def train(gpu, opt, output_dir, noises_init):
else:
start_epoch = 0
- def new_x_chain(x, num_chain):
- return torch.randn(num_chain, *x.shape[1:], device=x.device)
-
-
for epoch in range(start_epoch, opt.niter):
@@ -754,7 +749,7 @@ def main():
''' workaround '''
- train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes)
+ train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.category)
noises_init = torch.randn(len(train_dataset), opt.npoints-opt.svpoints, opt.nc)
if opt.dist_url == "env://" and opt.world_size == -1:
@@ -772,9 +767,8 @@ def main():
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
- parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
- help='input batch size')
+ parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/')
+ parser.add_argument('--dataroot_sv', default='GenReData/')
parser.add_argument('--category', default='chair')
parser.add_argument('--bs', type=int, default=48, help='input batch size')
diff --git a/shapenet/train_generation.py b/train_generation.py
similarity index 99%
rename from shapenet/train_generation.py
rename to train_generation.py
index 9141740..83d39ad 100644
--- a/shapenet/train_generation.py
+++ b/train_generation.py
@@ -4,7 +4,6 @@ import torch.optim as optim
import torch.utils.data
import argparse
-from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
@@ -787,10 +786,10 @@ def main():
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
+ parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/')
parser.add_argument('--category', default='chair')
- parser.add_argument('--bs', type=int, default=48, help='input batch size')
+ parser.add_argument('--bs', type=int, default=16, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
diff --git a/utils/binvox_rw.py b/utils/binvox_rw.py
deleted file mode 100644
index 73190d2..0000000
--- a/utils/binvox_rw.py
+++ /dev/null
@@ -1,266 +0,0 @@
-# Copyright (C) 2012 Daniel Maturana
-# This file is part of binvox-rw-py.
-#
-# binvox-rw-py is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# binvox-rw-py is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.
-#
-
-"""
-Binvox to Numpy and back.
->>> import numpy as np
->>> import binvox_rw
->>> with open('chair.binvox', 'rb') as f:
-... m1 = binvox_rw.read_as_3d_array(f)
-...
->>> m1.dims
-[32, 32, 32]
->>> m1.scale
-41.133000000000003
->>> m1.translate
-[0.0, 0.0, 0.0]
->>> with open('chair_out.binvox', 'wb') as f:
-... m1.write(f)
-...
->>> with open('chair_out.binvox', 'rb') as f:
-... m2 = binvox_rw.read_as_3d_array(f)
-...
->>> m1.dims==m2.dims
-True
->>> m1.scale==m2.scale
-True
->>> m1.translate==m2.translate
-True
->>> np.all(m1.data==m2.data)
-True
->>> with open('chair.binvox', 'rb') as f:
-... md = binvox_rw.read_as_3d_array(f)
-...
->>> with open('chair.binvox', 'rb') as f:
-... ms = binvox_rw.read_as_coord_array(f)
-...
->>> data_ds = binvox_rw.dense_to_sparse(md.data)
->>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)
->>> np.all(data_sd==md.data)
-True
->>> # the ordering of elements returned by numpy.nonzero changes with axis
->>> # ordering, so to compare for equality we first lexically sort the voxels.
->>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])
-True
-"""
-
-import numpy as np
-
-class Voxels(object):
- """ Holds a binvox model.
- data is either a three-dimensional numpy boolean array (dense representation)
- or a two-dimensional numpy float array (coordinate representation).
- dims, translate and scale are the model metadata.
- dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model.
- scale and translate relate the voxels to the original model coordinates.
- To translate voxel coordinates i, j, k to original coordinates x, y, z:
- x_n = (i+.5)/dims[0]
- y_n = (j+.5)/dims[1]
- z_n = (k+.5)/dims[2]
- x = scale*x_n + translate[0]
- y = scale*y_n + translate[1]
- z = scale*z_n + translate[2]
- """
-
- def __init__(self, data, dims, translate, scale, axis_order):
- self.data = data
- self.dims = dims
- self.translate = translate
- self.scale = scale
- assert (axis_order in ('xzy', 'xyz'))
- self.axis_order = axis_order
-
- def clone(self):
- data = self.data.copy()
- dims = self.dims[:]
- translate = self.translate[:]
- return Voxels(data, dims, translate, self.scale, self.axis_order)
-
- def write(self, fp):
- write(self, fp)
-
-def read_header(fp):
- """ Read binvox header. Mostly meant for internal use.
- """
- line = fp.readline().strip()
- if not line.startswith(b'#binvox'):
- raise IOError('Not a binvox file')
- dims = list(map(int, fp.readline().strip().split(b' ')[1:]))
- translate = list(map(float, fp.readline().strip().split(b' ')[1:]))
- scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0]
- line = fp.readline()
- return dims, translate, scale
-
-def read_as_3d_array(fp, fix_coords=True):
- """ Read binary binvox format as array.
- Returns the model with accompanying metadata.
- Voxels are stored in a three-dimensional numpy array, which is simple and
- direct, but may use a lot of memory for large models. (Storage requirements
- are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy
- boolean arrays use a byte per element).
- Doesn't do any checks on input except for the '#binvox' line.
- """
- dims, translate, scale = read_header(fp)
- raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
- # if just using reshape() on the raw data:
- # indexing the array as array[i,j,k], the indices map into the
- # coords as:
- # i -> x
- # j -> z
- # k -> y
- # if fix_coords is true, then data is rearranged so that
- # mapping is
- # i -> x
- # j -> y
- # k -> z
- values, counts = raw_data[::2], raw_data[1::2]
- data = np.repeat(values, counts).astype(np.bool)
- data = data.reshape(dims)
- if fix_coords:
- # xzy to xyz TODO the right thing
- data = np.transpose(data, (0, 2, 1))
- axis_order = 'xyz'
- else:
- axis_order = 'xzy'
- return Voxels(data, dims, translate, scale, axis_order)
-
-def read_as_coord_array(fp, fix_coords=True):
- """ Read binary binvox format as coordinates.
- Returns binvox model with voxels in a "coordinate" representation, i.e. an
- 3 x N array where N is the number of nonzero voxels. Each column
- corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates
- of the voxel. (The odd ordering is due to the way binvox format lays out
- data). Note that coordinates refer to the binvox voxels, without any
- scaling or translation.
- Use this to save memory if your model is very sparse (mostly empty).
- Doesn't do any checks on input except for the '#binvox' line.
- """
- dims, translate, scale = read_header(fp)
- raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
-
- values, counts = raw_data[::2], raw_data[1::2]
-
- sz = np.prod(dims)
- index, end_index = 0, 0
- end_indices = np.cumsum(counts)
- indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)
-
- values = values.astype(np.bool)
- indices = indices[values]
- end_indices = end_indices[values]
-
- nz_voxels = []
- for index, end_index in zip(indices, end_indices):
- nz_voxels.extend(range(index, end_index))
- nz_voxels = np.array(nz_voxels)
- # TODO are these dims correct?
- # according to docs,
- # index = x * wxh + z * width + y; // wxh = width * height = d * d
-
- x = nz_voxels / (dims[0]*dims[1])
- zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y
- z = zwpy / dims[0]
- y = zwpy % dims[0]
- if fix_coords:
- data = np.vstack((x, y, z))
- axis_order = 'xyz'
- else:
- data = np.vstack((x, z, y))
- axis_order = 'xzy'
-
- #return Voxels(data, dims, translate, scale, axis_order)
- return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
-
-def dense_to_sparse(voxel_data, dtype=np.int):
- """ From dense representation to sparse (coordinate) representation.
- No coordinate reordering.
- """
- if voxel_data.ndim!=3:
- raise ValueError('voxel_data is wrong shape; should be 3D array.')
- return np.asarray(np.nonzero(voxel_data), dtype)
-
-def sparse_to_dense(voxel_data, dims, dtype=np.bool):
- if voxel_data.ndim!=2 or voxel_data.shape[0]!=3:
- raise ValueError('voxel_data is wrong shape; should be 3xN array.')
- if np.isscalar(dims):
- dims = [dims]*3
- dims = np.atleast_2d(dims).T
- # truncate to integers
- xyz = voxel_data.astype(np.int)
- # discard voxels that fall outside dims
- valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
- xyz = xyz[:,valid_ix]
- out = np.zeros(dims.flatten(), dtype=dtype)
- out[tuple(xyz)] = True
- return out
-
-#def get_linear_index(x, y, z, dims):
- #""" Assuming xzy order. (y increasing fastest.
- #TODO ensure this is right when dims are not all same
- #"""
- #return x*(dims[1]*dims[2]) + z*dims[1] + y
-
-def write(voxel_model, fp):
- """ Write binary binvox format.
- Note that when saving a model in sparse (coordinate) format, it is first
- converted to dense format.
- Doesn't check if the model is 'sane'.
- """
- if voxel_model.data.ndim==2:
- # TODO avoid conversion to dense
- dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
- else:
- dense_voxel_data = voxel_model.data
-
- fp.write('#binvox 1\n')
- fp.write('dim '+' '.join(map(str, voxel_model.dims))+'\n')
- fp.write('translate '+' '.join(map(str, voxel_model.translate))+'\n')
- fp.write('scale '+str(voxel_model.scale)+'\n')
- fp.write('data\n')
- if not voxel_model.axis_order in ('xzy', 'xyz'):
- raise ValueError('Unsupported voxel model axis order')
-
- if voxel_model.axis_order=='xzy':
- voxels_flat = dense_voxel_data.flatten()
- elif voxel_model.axis_order=='xyz':
- voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()
-
- # keep a sort of state machine for writing run length encoding
- state = voxels_flat[0]
- ctr = 0
- for c in voxels_flat:
- if c==state:
- ctr += 1
- # if ctr hits max, dump
- if ctr==255:
- fp.write(chr(state))
- fp.write(chr(ctr))
- ctr = 0
- else:
- # if switch state, dump
- fp.write(chr(state))
- fp.write(chr(ctr))
- state = c
- ctr = 1
- # flush out remainders
- if ctr > 0:
- fp.write(chr(state))
- fp.write(chr(ctr))
-
-if __name__ == '__main__':
- import doctest
- doctest.testmod()
\ No newline at end of file
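
The removed `binvox_rw.py` above documents, in its `Voxels` docstring, how integer voxel indices map back to the original model's coordinates: normalize as `(i + 0.5) / dims`, then apply `scale` and `translate`. For readers who drop the module but still need that mapping, here is a minimal NumPy-only sketch of the same transform; the function name and the example grid are illustrative, not part of the repo:

```python
import numpy as np

def voxel_indices_to_world(indices, dims, translate, scale):
    """Map integer voxel indices (3, N) to model coordinates, following the
    (i + 0.5) / dims normalization described in the removed binvox_rw docstring."""
    dims = np.asarray(dims, dtype=np.float32).reshape(3, 1)
    translate = np.asarray(translate, dtype=np.float32).reshape(3, 1)
    normalized = (indices.astype(np.float32) + 0.5) / dims  # voxel centers in [0, 1]
    return scale * normalized + translate                   # back to model coordinates

# Example: a 32^3 occupancy grid with a single filled voxel.
grid = np.zeros((32, 32, 32), dtype=bool)
grid[3, 10, 20] = True
ijk = np.asarray(np.nonzero(grid))                          # (3, N) sparse coordinates
print(voxel_indices_to_world(ijk, grid.shape, [0.0, 0.0, 0.0], 41.133))
```
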
diff --git a/utils/conversion.py b/utils/conversion.py
deleted file mode 100644
index 1d90335..0000000
--- a/utils/conversion.py
+++ /dev/null
@@ -1,46 +0,0 @@
-
-from skimage import measure
-import numpy as np
-
-
-def get_mesh(tsdf_vol, color_vol, threshold=0, vol_max=.5, vol_min=-.5):
- """Compute a mesh from the voxel volume using marching cubes.
- """
- vol_origin = vol_min
- voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1]
-
- # Marching cubes
- verts, faces, norms, vals = measure.marching_cubes_lewiner(tsdf_vol, level=threshold)
- verts_ind = np.round(verts).astype(int)
- verts = verts * voxel_size + vol_origin # voxel grid coordinates to world coordinates
-
- # Get vertex colors
- if color_vol is None:
- return verts, faces, norms
- colors = color_vol[:, verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].T
-
- return verts, faces, norms, colors
-
-
-def get_point_cloud(tsdf_vol, color_vol, vol_max=0.5, vol_min=-0.5):
- vol_origin = vol_min
- voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1]
- # Marching cubes
- verts = measure.marching_cubes_lewiner(tsdf_vol, level=0)[0]
- verts_ind = np.round(verts).astype(int)
- verts = verts * voxel_size + vol_origin
-
- # Get vertex colors
- colors = color_vol[:, verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].T
-
- pc = np.hstack([verts, colors])
-
- return pc
-
-def sparse_to_dense_voxel(coords, feats, res):
- coords = coords.astype('int64', copy=False)
- a = np.zeros((res, res, res), dtype=feats.dtype)
-
- a[coords[:,0],coords[:,1],coords[:,2] ] = feats[:,0].astype(a.dtype, copy=False)
-
- return a
\ No newline at end of file
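
The deleted `conversion.py` wraps scikit-image's marching cubes and rescales the vertices from voxel-grid indices into the volume's world extent. A hedged, self-contained sketch of the same idea follows; note that recent scikit-image releases expose `measure.marching_cubes` rather than the `marching_cubes_lewiner` name used in the removed code, and all names below are illustrative:

```python
import numpy as np
from skimage import measure  # recent scikit-image; marching_cubes replaces marching_cubes_lewiner

def tsdf_to_mesh(tsdf_vol, vol_min=-0.5, vol_max=0.5, level=0.0):
    """Extract an isosurface from a cubic TSDF volume and rescale vertices from
    grid coordinates into [vol_min, vol_max], mirroring the removed get_mesh."""
    voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1]
    verts, faces, normals, _ = measure.marching_cubes(tsdf_vol, level=level)
    return verts * voxel_size + vol_min, faces, normals

# Example: a sphere of radius 0.3 as a signed distance field on a 64^3 grid.
res = 64
coords = (np.mgrid[:res, :res, :res] + 0.5) / res - 0.5
sdf = np.linalg.norm(coords, axis=0) - 0.3
verts, faces, normals = tsdf_to_mesh(sdf)
print(verts.shape, faces.shape)
```
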
diff --git a/utils/mitsuba_renderer.py b/utils/mitsuba_renderer.py
deleted file mode 100644
index 26cc3bf..0000000
--- a/utils/mitsuba_renderer.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import numpy as np
-from pathlib import Path
-import os
-
-
-def standardize_bbox(pcl, points_per_object, scale=None):
- pt_indices = np.random.choice(pcl.shape[0], points_per_object, replace=False)
- np.random.shuffle(pt_indices)
- pcl = pcl[pt_indices] # n by 3
- mins = np.amin(pcl, axis=0)
- maxs = np.amax(pcl, axis=0)
- center = (mins + maxs) / 2.
- if scale is None:
- scale = np.amax(maxs - mins)
- result = ((pcl - center) / scale).astype(np.float32) # [-0.5, 0.5]
- return result
-
-
-xml_head = \
-    """[Mitsuba scene-header XML template; its single {} placeholder takes the per-category camera field of view via xml_head.format(fov_map[cat])]"""
-
-xml_ball_segment = \
-    """[Per-point XML template: one small sphere per point, with placeholders for radius, translate x/y/z, and diffuse RGB filled via xml_ball_segment.format(...)]"""
-
-xml_tail = \
-    """[Scene-tail XML template; its single {} placeholder is filled with the point cloud's minimum z via xml_tail.format(pcl[:, 2].min())]"""
-
-
-def colormap_fn(x, y, z):
- vec = np.array([x, y, z])
- vec = np.clip(vec, 0.001, 1.0)
- norm = np.sqrt(np.sum(vec ** 2))
- vec /= norm
- return [vec[0], vec[1], vec[2]]
-
-
-color_dict = {'r': [163, 102, 96], 'g': [20, 130, 3],
- 'o': [145, 128, 47], 'b': [91, 102, 112], 'p':[133,111,139], 'br':[111,92,81]}
-
-color_map = {'airplane': 'r', 'chair': 'o', 'car': 'b', 'table': 'p', 'lamp':'br'}
-fov_map = {'airplane': 12, 'chair': 16, 'car':15, 'table': 13, 'lamp':13}
-radius_map = {'airplane': 0.02, 'chair': 0.035, 'car': 0.01, 'table':0.035, 'lamp':0.035}
-
-def write_to_xml_batch(dir, pcl_batch, filenames=None, color_batch=None, cat='airplane'):
- default_color = color_map[cat]
- Path(dir).mkdir(parents=True, exist_ok=True)
- if filenames is not None:
- assert len(filenames) == pcl_batch.shape[0]
- # mins = np.amin(pcl_batch, axis=(0,1))
- # maxs = np.amax(pcl_batch, axis=(0,1))
- # scale = 1; print(np.amax(maxs - mins))
-
- for k, pcl in enumerate(pcl_batch):
- xml_segments = [xml_head.format(fov_map[cat])]
- pcl = standardize_bbox(pcl, pcl.shape[0])
- pcl = pcl[:, [2, 0, 1]]
- pcl[:, 0] *= -1
- pcl[:, 2] += 0.0125
- for i in range(pcl.shape[0]):
- if color_batch is not None:
- color = color_batch[k, i]
- else:
- color = np.array(color_dict[default_color]) / 255
- # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125)
- xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color))
- xml_segments.append(
- xml_tail.format(pcl[:, 2].min()))
-
- xml_content = str.join('', xml_segments)
-
- if filenames is None:
- fn = 'sample_{}.xml'.format(k)
- else:
- fn = filenames[k]
- with open(os.path.join(dir, fn), 'w') as f:
- f.write(xml_content)
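
Although the XML templates above did not survive intact, the point-cloud preprocessing in the deleted renderer is fully visible: each cloud is rescaled into [-0.5, 0.5]^3, its axes reordered, the new x axis mirrored, and the points lifted slightly off the ground before one sphere per point is written out. A small sketch of just that preprocessing (the function name is illustrative):

```python
import numpy as np

def normalize_for_render(pcl, lift=0.0125):
    """Center and scale an (N, 3) cloud to [-0.5, 0.5]^3, then apply the axis
    shuffle the removed mitsuba_renderer.py used before emitting spheres."""
    mins, maxs = pcl.min(axis=0), pcl.max(axis=0)
    pcl = (pcl - (mins + maxs) / 2.0) / np.max(maxs - mins)  # longest side spans 1
    pcl = pcl[:, [2, 0, 1]]                                  # reorder axes as in the deleted code
    pcl[:, 0] *= -1                                          # mirror the new x axis
    pcl[:, 2] += lift                                        # keep points just above the floor
    return pcl.astype(np.float32)

print(normalize_for_render(np.random.rand(2048, 3)).shape)
```
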
diff --git a/utils/mitsuba_renderer2.py b/utils/mitsuba_renderer2.py
deleted file mode 100644
index cb725fd..0000000
--- a/utils/mitsuba_renderer2.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import numpy as np
-from pathlib import Path
-import os
-
-
-def standardize_bbox(pcl, points_per_object):
- pt_indices = np.random.choice(pcl.shape[0], points_per_object, replace=False)
- np.random.shuffle(pt_indices)
- pcl = pcl[pt_indices] # n by 3
- mins = np.amin(pcl, axis=0)
- maxs = np.amax(pcl, axis=0)
- center = (mins + maxs) / 2.
- scale = np.amax(maxs - mins)
- result = ((pcl - center) / scale).astype(np.float32) # [-0.5, 0.5]
- return result
-
-
-xml_head = \
-    """[Mitsuba scene-header XML template; its three {} placeholders take the camera origin x, y, z via xml_head.format(x, y, z)]"""
-
-xml_ball_segment = \
-    """[Per-point XML template: one small sphere per point, with placeholders for radius, translate x/y/z, and diffuse RGB filled via xml_ball_segment.format(...)]"""
-
-xml_tail = \
-    """[Scene-tail XML template; its single {} placeholder is filled with the point cloud's minimum z via xml_tail.format(pcl[:, 2].min())]"""
-
-
-def colormap_fn(x, y, z):
- vec = np.array([x, y, z])
- vec = np.clip(vec, 0.001, 1.0)
- norm = np.sqrt(np.sum(vec ** 2))
- vec /= norm
- return [vec[0], vec[1], vec[2]]
-
-
-color_dict = {'r': [163, 102, 96], 'p': [133,111,139], 'g': [20, 130, 3],
- 'o': [145, 128, 47], 'b': [91, 102, 112]}
-
-color_map = {'airplane': 'r', 'chair': 'o', 'car': 'b', 'table': 'p'}
-fov_map = {'airplane': 12, 'chair': 15, 'car':12, 'table':12}
-radius_map = {'airplane': 0.0175, 'chair': 0.035, 'car': 0.025, 'table': 0.02}
-
-def write_to_xml_batch(dir, pcl_batch, color_batch=None, cat='airplane', elev=15, azim=45, radius=np.sqrt(18)):
- elev_rad = elev * np.pi / 180
- azim_rad = azim * np.pi / 180
-
- x = radius * np.cos(elev_rad)*np.cos(azim_rad)
- y = radius * np.cos(elev_rad)*np.sin(azim_rad)
- z = radius * np.sin(elev_rad)
-
- default_color = color_map[cat]
- Path(dir).mkdir(parents=True, exist_ok=True)
- for k, pcl in enumerate(pcl_batch):
- xml_segments = [xml_head.format(x,y,z)]
- pcl = standardize_bbox(pcl, pcl.shape[0])
- pcl = pcl[:, [2, 0, 1]]
- pcl[:, 0] *= -1
- pcl[:, 2] += 0.0125
- for i in range(pcl.shape[0]):
- if color_batch is not None:
- color = color_batch[k, i]
- else:
- color = np.array(color_dict[default_color]) / 255
- # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125)
- xml_segments.append(xml_ball_segment.format(0.0175, pcl[i, 0], pcl[i, 1], pcl[i, 2], *color))
- xml_segments.append(
- xml_tail.format(pcl[:, 2].min()))
-
- xml_content = str.join('', xml_segments)
-
- with open(os.path.join(dir, 'sample_{}.xml'.format(k)), 'w') as f:
- f.write(xml_content)
-
-def write_to_xml(file, pcl, cat='airplane', elev=15, azim=45, radius=np.sqrt(18)):
- assert pcl.ndim == 2
- elev_rad = elev * np.pi / 180
- azim_rad = azim * np.pi / 180
-
- x = radius * np.cos(elev_rad)*np.cos(azim_rad)
- y = radius * np.cos(elev_rad)*np.sin(azim_rad)
- z = radius * np.sin(elev_rad)
-
- default_color = color_map[cat]
-
- xml_segments = [xml_head.format(x,y,z)]
- pcl = standardize_bbox(pcl, pcl.shape[0])
- pcl = pcl[:, [2, 0, 1]]
- pcl[:, 0] *= -1
- pcl[:, 2] += 0.0125
- for i in range(pcl.shape[0]):
- color = np.array(color_dict[default_color]) / 255
- # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125)
- xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color))
- xml_segments.append(
- xml_tail.format(pcl[:, 2].min()))
-
- xml_content = str.join('', xml_segments)
-
- with open(file, 'w') as f:
- f.write(xml_content)
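
`mitsuba_renderer2.py` places the camera from an elevation angle, an azimuth angle, and a radius using the usual spherical-to-Cartesian conversion before formatting the result into the scene header. A small worked sketch of that placement, using the deleted code's default view (elevation 15 degrees, azimuth 45 degrees, radius sqrt(18)):

```python
import numpy as np

def camera_origin(elev_deg=15.0, azim_deg=45.0, radius=np.sqrt(18)):
    """Spherical-to-Cartesian camera placement as in the removed mitsuba_renderer2.py:
    elevation above the xy-plane, azimuth around the z axis."""
    elev, azim = np.radians(elev_deg), np.radians(azim_deg)
    x = radius * np.cos(elev) * np.cos(azim)
    y = radius * np.cos(elev) * np.sin(azim)
    z = radius * np.sin(elev)
    return x, y, z

print(camera_origin())  # roughly (2.90, 2.90, 1.10), about 4.24 units from the origin
```
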
diff --git a/utils/xml_from_mesh.py b/utils/xml_from_mesh.py
deleted file mode 100644
index b833148..0000000
--- a/utils/xml_from_mesh.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import sys
-
-sys.path.append('..')
-import argparse
-import os
-import numpy as np
-import trimesh
-import glob
-from joblib import Parallel, delayed
-import re
-from utils.mitsuba_renderer import write_to_xml_batch
-def rotation_matrix(axis, theta):
- """
- Return the rotation matrix associated with counterclockwise rotation about
- the given axis by theta radians.
- """
- axis = np.asarray(axis)
- axis = axis / np.sqrt(np.dot(axis, axis))
- a = np.cos(theta / 2.0)
- b, c, d = -axis * np.sin(theta / 2.0)
- aa, bb, cc, dd = a * a, b * b, c * c, d * d
- bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
- return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
- [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
- [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
-
-def rotate(vertices, faces):
- '''
- vertices: [numpoints, 3]
- '''
- N = rotation_matrix([1, 0, 0], 3* np.pi / 4).transpose()
- # M = rotation_matrix([0, 1, 0], -np.pi / 2).transpose()
-
-
- v, f = vertices.dot(N), faces
- return v, f
-
-def as_mesh(scene_or_mesh):
- if isinstance(scene_or_mesh, trimesh.Scene):
- mesh = trimesh.util.concatenate([
- trimesh.Trimesh(vertices=m.vertices, faces=m.faces)
- for m in scene_or_mesh.geometry.values()])
- else:
- mesh = scene_or_mesh
- return mesh
-def process_one(shape_dir, cat):
- pc_paths = glob.glob(os.path.join(shape_dir, "*.obj"))
- pc_paths = sorted(pc_paths)
-
- xml_paths = [] #[re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths]
-
- gen_pcs = []
- for path in pc_paths:
- sample_mesh = trimesh.load(path, force='mesh')
- v, f = rotate(sample_mesh.vertices,sample_mesh.faces)
- mesh = trimesh.Trimesh(v, f)
- sample_pts = trimesh.sample.sample_surface(mesh, 2048)[0]
- gen_pcs.append(sample_pts)
- xml_paths.append(re.sub('.obj', '.xml', os.path.basename(path)))
-
-
-
- gen_pcs = np.stack(gen_pcs, axis=0)
- write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths, cat=cat)
-
-
-def process(args):
- shape_names = [n for n in sorted(os.listdir(args.src)) if
- os.path.isdir(os.path.join(args.src, n)) and not n.startswith('x')]
-
- all_shape_dir = [os.path.join(args.src, name) for name in shape_names]
-
- Parallel(n_jobs=10, verbose=2)(delayed(process_one)(path) for path in all_shape_dir)
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--src", type=str)
- parser.add_argument("--cat", type=str)
- args = parser.parse_args()
-
- process_one(args.src, args.cat)
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
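
The `rotation_matrix` helper in the deleted `xml_from_mesh.py` is the standard quaternion-style construction of an axis-angle rotation, and `rotate()` right-multiplies the (N, 3) vertex array by its transpose, which applies the rotation to every row vector. A short sanity check against SciPy (already listed in the requirements) may make the convention clearer; the variable names are illustrative:

```python
import numpy as np
from scipy.spatial.transform import Rotation

# The x-axis, 3*pi/4 rotation that the removed rotate() applied to every mesh.
angle, axis = 3 * np.pi / 4, np.array([1.0, 0.0, 0.0])

# Quaternion-style construction, as in the removed rotation_matrix helper.
a = np.cos(angle / 2.0)
b, c, d = -axis * np.sin(angle / 2.0)
R = np.array([
    [a*a + b*b - c*c - d*d, 2*(b*c + a*d),         2*(b*d - a*c)],
    [2*(b*c - a*d),         a*a + c*c - b*b - d*d, 2*(c*d + a*b)],
    [2*(b*d + a*c),         2*(c*d - a*b),         a*a + d*d - b*b - c*c],
])

print(np.allclose(R, Rotation.from_rotvec(angle * axis).as_matrix()))  # True
# vertices @ R.T applies R to each row vector, hence the .transpose() in rotate().
```
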
diff --git a/utils/xml_from_ply.py b/utils/xml_from_ply.py
deleted file mode 100644
index 6a0aff1..0000000
--- a/utils/xml_from_ply.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import sys
-
-sys.path.append('..')
-import argparse
-import os
-import numpy as np
-import trimesh
-import glob
-from joblib import Parallel, delayed
-import re
-from utils.mitsuba_renderer import write_to_xml_batch
-
-
-def process_one(shape_dir):
- pc_paths = glob.glob(os.path.join(shape_dir, "fake*.ply"))
- pc_paths = sorted(pc_paths)
-
- xml_paths = [re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths]
-
- gen_pcs = []
- for path in pc_paths:
- sample_pts = trimesh.load(path)
- sample_pts = np.array(sample_pts.vertices)
- gen_pcs.append(sample_pts)
-
- raw_pc = np.array(trimesh.load(os.path.join(shape_dir, "raw.ply")).vertices)
- raw_pc = np.concatenate([raw_pc, np.tile(raw_pc[0:1], (gen_pcs[0].shape[0]-raw_pc.shape[0],1))])
-
- gen_pcs.append(raw_pc)
- gen_pcs = np.stack(gen_pcs, axis=0)
- xml_paths.append('raw.xml')
-
- write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths)
-
-
-def process(args):
- shape_names = [n for n in sorted(os.listdir(args.src)) if
- os.path.isdir(os.path.join(args.src, n)) and not n.startswith('x')]
-
- all_shape_dir = [os.path.join(args.src, name) for name in shape_names]
-
- Parallel(n_jobs=10, verbose=2)(delayed(process_one)(path) for path in all_shape_dir)
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--src", type=str)
- args = parser.parse_args()
-
- process_one(args)
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file