This commit is contained in:
Linqi (Alex) Zhou 2021-11-01 00:12:02 -07:00
parent 2f6aa752a6
commit 958537389a
34 changed files with 177 additions and 11490 deletions

View file

@ -1,12 +1,13 @@
# Shape Generation and Completion Through Point-Voxel Diffusion
<p float="left">
<img src="assets/pvd_teaser.gif" width="80%"/>
</p>
[Project]() | [Paper]()
[Project](https://alexzhou907.github.io/pvd) | [Paper](https://arxiv.org/abs/2104.03670)
Implementation of
Implementation of Shape Generation and Completion Through Point-Voxel Diffusion
## Pretrained Models
Pretrained models can be accessed [here](https://www.dropbox.com/s/a3xydf594fzaokl/cifar10_pretrained.rar?dl=0).
[Linqi Zhou](https://alexzhou907.github.io), [Yilun Du](https://yilundu.github.io/), [Jiajun Wu](https://jiajunwu.com/)
## Requirements:
@ -20,24 +21,56 @@ cudatoolkit==10.1
matplotlib==2.2.5
tqdm==4.32.1
open3d==0.9.0
trimesh==3.7.12
scipy==1.5.1
```
Install PyTorchEMD by
```
cd metrics/PyTorchEMD
python setup.py install
cp build/**/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
```
The code was tested on Ubuntu with a Titan RTX.
## Data
## Training on CIFAR-10:
For generation, we use ShapeNet point cloud, which can be downloaded [here](https://github.com/stevenygd/PointFlow).
For completion, we use ShapeNet rendering provided by [GenRe](https://github.com/xiumingzhang/GenRe-ShapeHD).
We provide script `convert_cam_params.py` to process the provided data.
For training the model on shape completion, we need camera parameters for each view
which are not directly available. To obtain these, simply run
```bash
$ python convert_cam_params.py --dataroot DATA_DIR --mitsuba_xml_root XML_DIR
```
which will create `..._cam_params.npz` in each provided data folder for each view.
## Pretrained models
Pretrained models can be downloaded [here](https://drive.google.com/drive/folders/1Q7aSaTr6lqmo8qx80nIm1j28mOHAHGiM?usp=sharing).
## Training:
```bash
$ python train_cifar.py
$ python train_generation.py --category car|chair|airplane
```
Please refer to the python file for optimal training parameters.
## Testing:
```bash
$ python train_generation.py --category car|chair|airplane --model MODEL_PATH
```
## Results
Some generative results are as follows.
<p float="left">
<img src="example/cifar_gen.png" width="300"/>
<img src="example/lsun_gen.png" width="300"/>
<img src="assets/cifar_gen.png" width="300"/>
<img src="assets/lsun_gen.png" width="300"/>
</p>

BIN
assets/gen_comp.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 MiB

BIN
assets/mm_partnet.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 MiB

BIN
assets/mm_redwood.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.7 MiB

BIN
assets/mm_shapenet.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 MiB

BIN
assets/pvd_teaser.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 MiB

123
convert_cam_params.py Normal file
View file

@ -0,0 +1,123 @@
from glob import glob
import re
import argparse
import numpy as np
from pathlib import Path
import os
def raw_camparam_from_xml(path, pose="lookAt"):
    """Read raw camera parameters from a Mitsuba scene XML file.

    Returns a dict with float32 vectors 'origin', 'up', 'target' (parsed
    from comma-separated attribute strings) and integer film 'height' and
    'width'.
    """
    import xml.etree.ElementTree as ET
    tree = ET.parse(path)

    lookat_node = tree.find("./sensor/transform/" + pose)
    attrs = lookat_node.attrib

    def _vec(key):
        # Attribute values are comma-separated floats, e.g. "0,1,2".
        return np.fromstring(attrs[key], dtype=np.float32, sep=',')

    film_height = int(
        tree.find("./sensor/film/integer[@name='height']").attrib['value'])
    film_width = int(
        tree.find("./sensor/film/integer[@name='width']").attrib['value'])

    return {
        'origin': _vec('origin'),
        'up': _vec('up'),
        'target': _vec('target'),
        'height': film_height,
        'width': film_width,
    }
def get_cam_pos(origin, target, up):
    """Build a 4x4 extrinsic (world-to-camera) matrix from a look-at spec.

    origin/target/up are 3-vectors (numpy arrays). The camera basis is
    re-orthogonalized from the (possibly non-orthogonal) `up` hint.
    """
    view_dir = origin - target              # points from target toward camera
    side = np.cross(up, view_dir)
    up_ortho = np.cross(view_dir, side)     # up re-orthogonalized to view_dir

    # Unit camera axes: right, up, forward (negated below).
    rx = np.cross(up_ortho, view_dir)
    ry = np.array(up_ortho)
    rz = np.array(view_dir)
    rx = rx / np.linalg.norm(rx)
    ry = ry / np.linalg.norm(ry)
    rz = rz / np.linalg.norm(rz)

    rotation = np.stack([rx, ry, -rz], axis=0)
    translate = np.concatenate([np.eye(3), -origin[:, None]], axis=1)
    extrinsic = np.matmul(rotation, translate)

    # Append homogeneous row to get a full 4x4 transform.
    return np.concatenate([extrinsic, np.array([[0, 0, 0, 1]])], axis=0)
def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
    """Compute and save camera parameters for every rendered view in a folder.

    For each `*depth.png` under `datapoint_dir`, finds the matching Mitsuba
    XML under `camera_param_dir` (mirroring the path under `dataroot`),
    derives the 4x4 extrinsic from its look-at parameters, and writes a
    `..._cam_params.npz` file (keys 'extr', 'intr') next to the depth map.
    Views whose XML file is missing are skipped silently.
    """
    depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png')))
    # BUGFIX: the original used re.sub(dataroot, ...), which interprets the
    # path as a regular expression and breaks on special characters; a plain
    # literal substring replacement is what is intended here.
    cam_ext = [
        '_'.join(
            f.replace(dataroot.strip('/'), camera_param_dir.strip('/')).split('_')[:-1]
        ) + '.xml'
        for f in depths
    ]

    # The intrinsics are identical for all views, so build K once outside the
    # loop: 35mm-film-diagonal sensor, 50mm focal length, 480x480 rendering
    # (values match the GenRe renderer setup).
    diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5
    focal_length = 0.05
    res = [480, 480]
    h_relative = (res[1] / res[0])
    sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2))
    pix_size = sensor_width / res[0]
    K = np.array([
        [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
        [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
        [0, 0, 1]
    ])

    for f, pth in zip(cam_ext, depths):
        if not os.path.exists(f):
            continue
        params = raw_camparam_from_xml(f)
        ext_matrix = get_cam_pos(params['origin'], params['target'], params['up'])
        np.savez(pth.split('depth.png')[0] + 'cam_params.npz', extr=ext_matrix, intr=K)
def main(opt):
    """Walk the dataset tree and generate camera params for every leaf dir."""
    root = Path(opt.dataroot)
    # A "leaf" is a directory with no subdirectories; skip the XML tree itself.
    leaves = [
        dirpath
        for dirpath, dirnames, _filenames in os.walk(root)
        if not dirnames and opt.mitsuba_xml_root not in dirpath
    ]
    total = len(leaves)
    for idx, leaf in enumerate(leaves):
        print('Processing dir {}/{}: {}'.format(idx, total, leaf))
        convert_cam_params_all_views(leaf, opt.dataroot, opt.mitsuba_xml_root)
if __name__ == '__main__':
    # CLI entry point; defaults match the GenRe data layout.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', type=str, default='GenReData/')
    parser.add_argument('--mitsuba_xml_root', type=str, default='GenReData/genre-xml_v2')
    main(parser.parse_args())

View file

@ -5,9 +5,7 @@ import os
import json
import random
import trimesh
import csv
from plyfile import PlyData, PlyElement
from glob import glob
def project_pc_to_image(points, resolution=64):
"""project point clouds into 2D image
@ -181,33 +179,3 @@ class GANdatasetPartNet(Dataset):
if __name__ == '__main__':
    # Manual smoke test: load one PartNet sample and dump the concatenated
    # partial + complete cloud to a PLY file for visual inspection.
    # NOTE(review): paths are hard-coded to the author's cluster; adjust
    # before running locally.
    data_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetPointCloud'
    data_raw_root = '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc'
    pc_dataroot = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k'
    sn_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2'
    classes = 'car'
    npoints = 2048
    # from datasets.shapenet_data_pc import ShapeNet15kPointClouds
    # pc_ds = ShapeNet15kPointClouds(root_dir=pc_dataroot,
    # categories=[classes], split='train',
    # tr_sample_size=npoints,
    # te_sample_size=npoints,
    # scale=1.,
    # normalize_per_shape=False,
    # normalize_std_per_axis=False,
    # random_subsample=True)
    train_ds = GANdatasetPartNet('test', pc_dataroot, data_raw_root, classes, npoints, np.array([0,0,0]),
    np.array([1, 1, 1]))
    d1 = train_ds[0]
    real = d1['real']        # complete shape points
    raw = d1['raw']          # partial observation points
    # m, s appear to be the de-normalization shift/scale — TODO confirm
    # against GANdatasetPartNet.
    m, s = d1['m'], d1['s']
    # De-normalize and stack partial + complete along the point axis,
    # then transpose to (num_points, 3) for PLY export.
    x = (torch.cat([raw, real], dim=-1) * s + m).transpose(0,1)
    write_ply(x.numpy(), 'x.ply')
    pass

View file

@ -1,825 +0,0 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from model.pvcnn_completion import PVCNN2Base
import torch.distributed as dist
from datasets.partnet import GANdatasetPartNet
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
'''
some utils
'''
def rotation_matrix(axis, theta):
    """Rotation matrix for a counterclockwise rotation of `theta` radians
    about `axis`, via the Euler-Rodrigues formula."""
    u = np.asarray(axis)
    u = u / np.sqrt(np.dot(u, u))           # normalize the rotation axis
    half = theta / 2.0
    a = np.cos(half)
    b, c, d = -u * np.sin(half)
    # Precompute the quadratic quaternion terms.
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.array([
        [aa + bb - cc - dd, 2 * (bc + ad),     2 * (bd - ac)],
        [2 * (bc - ad),     aa + cc - bb - dd, 2 * (cd + ab)],
        [2 * (bd + ac),     2 * (cd - ab),     aa + dd - bb - cc],
    ])
def rotate(vertices, faces):
    """Rotate a mesh into the canonical viewing frame.

    vertices: [numpoints, 3]
    Faces are only re-indexed (same axis permutation as the vertices).
    """
    rot_y = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
    rot_x = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
    rot_z = rotation_matrix([0, 0, 1], np.pi).transpose()
    permuted = vertices[:, [1, 2, 0]]
    transformed = permuted.dot(rot_y).dot(rot_x).dot(rot_z)
    return transformed, faces[:, [1, 2, 0]]
def norm(v, f):
    """Min-max normalize vertex coordinates into [-0.5, 0.5]; faces pass through."""
    span = v.max() - v.min()
    scaled = (v - v.min()) / span
    return scaled - 0.5, f
def getGradNorm(net):
    """Return (L2 norm of all parameters, L2 norm of all gradients) of `net`."""
    param_sq = sum(torch.sum(p ** 2) for p in net.parameters())
    grad_sq = sum(torch.sum(p.grad ** 2) for p in net.parameters())
    return torch.sqrt(param_sq), torch.sqrt(grad_sq)
def weights_init(m):
    """Xavier-normal init for conv layers; N(0,1) weights and zero bias for
    batch-norm layers. Intended for use with `net.apply(weights_init)`."""
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name and m.weight is not None:
        torch.nn.init.xavier_normal_(m.weight)
    elif 'BatchNorm' in layer_name:
        m.weight.data.normal_()
        m.bias.data.fill_(0)
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))."""
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + var_ratio + scaled_sq_diff)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-likelihood of `x` under a Gaussian discretized to bins of width 1.

    Edge bins are open-ended: values below ~0 use the lower tail mass and
    values above ~1 use the upper tail mass.
    """
    # Assumes data is integers [0, 1]
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered = x - means
    inv_std = torch.exp(-log_scales)
    cdf_upper = std_normal.cdf(inv_std * (centered + 0.5))
    cdf_lower = std_normal.cdf(inv_std * (centered - .5))

    # Clamp all masses away from zero before taking logs.
    log_cdf_upper = torch.log(torch.max(cdf_upper, torch.ones_like(cdf_upper) * 1e-12))
    log_one_minus_cdf_lower = torch.log(torch.max(1. - cdf_lower, torch.ones_like(cdf_lower) * 1e-12))
    log_bin_mass = torch.log(torch.max(cdf_upper - cdf_lower, torch.ones_like(cdf_upper) * 1e-12))

    log_probs = torch.where(
        x < 0.001, log_cdf_upper,
        torch.where(x > 0.999, log_one_minus_cdf_lower, log_bin_mass))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.loss_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
repeat_noise_steps (int): Number of denoising timesteps in which the same noise
is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """Concrete PVCNN2 architecture: block configurations for the completion
    network. The meaning of each tuple field is defined by PVCNN2Base (not
    visible here) — presumably set-abstraction and feature-propagation
    configs in the PVCNN style; confirm against model/pvcnn_completion.py.
    """
    # Set-abstraction (encoder) block configs.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    # Feature-propagation (decoder) block configs.
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through to the base class; this subclass only pins the
        # sa_blocks/fp_blocks layout above.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Bundles the PVCNN2 denoiser with its GaussianDiffusion process and
    exposes training/sampling entry points used by train()."""

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)

        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # KL between q(x_T | x_0) and the standard-normal prior (bits/dim).
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        # Full variational bound breakdown over all timesteps.
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt': mse_bt
        }

    def _denoise(self, data, t):
        # data: (B, D, N) float point cloud (conditioning points included);
        # t: (B,) int64 timesteps.
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        # One training step's per-example losses at uniformly sampled t.
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # Resample noise for all examples with t != 0; only t == 0 keeps
            # the fixed per-example noise passed in (noises_init workaround).
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        # Complete `partial_x` by sampling the reverse diffusion chain.
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    # NOTE(review): these override nn.Module.train/eval without the `mode`
    # argument and only toggle the inner model.
    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Wrap the inner network with a (Distributed)DataParallel factory.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build a beta (noise-variance) schedule of length `time_num`.

    Supported schedule types:
      'linear' : linear ramp from b_start to b_end.
      'warmX'  : constant b_end, with the first X fraction of steps ramping
                 linearly from b_start to b_end (e.g. 'warm0.1', 'warm0.5').
                 Generalized from the original hard-coded 0.1/0.2/0.5 set;
                 those exact strings behave identically to before.

    Raises NotImplementedError for any other schedule type.
    """
    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num)
    if schedule_type.startswith('warm'):
        try:
            warm_frac = float(schedule_type[len('warm'):])
        except ValueError:
            # e.g. 'warmup' — not a recognized schedule.
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * warm_frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
        return betas
    raise NotImplementedError(schedule_type)
def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category):
    """Build the PartNet training dataset.

    NOTE(review): data_raw_root and pc_dataroot are accepted but unused
    here; they are kept in the signature for call-site compatibility.
    """
    train_ds = GANdatasetPartNet('train', data_root, category, npoints)
    return train_ds
def get_dataloader(opt, train_dataset, test_dataset=None):
    """Create train/test DataLoaders, with DistributedSamplers when
    opt.distribution_type == 'multi'.

    Returns (train_dataloader, test_dataloader, train_sampler, test_sampler);
    the test entries are None when no test_dataset is given.
    """
    if opt.distribution_type == 'multi':
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=opt.world_size,
            rank=opt.rank
        )
        if test_dataset is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=opt.world_size,
                rank=opt.rank
            )
        else:
            test_sampler = None
    else:
        train_sampler = None
        test_sampler = None

    # Shuffle only when no sampler is used (a sampler owns the ordering).
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs, sampler=train_sampler,
                                                   shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)

    if test_dataset is not None:
        # BUGFIX: the original passed train_dataset here, so the "test"
        # loader silently served training data.
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
    else:
        test_dataloader = None

    return train_dataloader, test_dataloader, train_sampler, test_sampler
def train(gpu, opt, output_dir, noises_init):
    """Training worker: one process per GPU under 'multi', otherwise runs
    inline.

    gpu         : GPU index for this worker (may be None).
    opt         : parsed argparse options.
    output_dir  : experiment output directory (logs, checkpoints, samples).
    noises_init : fixed per-example noise tensor, indexed by dataset idx.
    """
    set_seed(opt)
    logger = setup_logging(output_dir)
    # Only rank-0 (or the single process) logs diagnostics and saves samples.
    if opt.distribution_type == 'multi':
        should_diag = gpu==0
    else:
        should_diag = True
    if should_diag:
        outf_syn, = setup_output_subdirs(output_dir, 'syn')

    if opt.distribution_type == 'multi':
        if opt.dist_url == "env://" and opt.rank == -1:
            opt.rank = int(os.environ["RANK"])

        # Global rank = node rank * gpus-per-node + local gpu index.
        base_rank = opt.rank * opt.ngpus_per_node
        opt.rank = base_rank + gpu
        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
                                world_size=opt.world_size, rank=opt.rank)

        # Per-process batch size; intervals scaled so wall-clock cadence of
        # save/diag/viz stays roughly constant across world sizes.
        opt.bs = int(opt.bs / opt.ngpus_per_node)
        opt.workers = 0

        opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
        opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
        opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)

    ''' data '''
    train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes)
    dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)

    '''
    create networks
    '''
    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.distribution_type == 'multi':  # Multiple processes, single GPU per process
        def _transform_(m):
            return nn.parallel.DistributedDataParallel(
                m, device_ids=[gpu], output_device=gpu)

        torch.cuda.set_device(gpu)
        netE.cuda(gpu)
        netE.multi_gpu_wrapper(_transform_)

    elif opt.distribution_type == 'single':
        def _transform_(m):
            return nn.parallel.DataParallel(m)
        netE = netE.cuda()
        netE.multi_gpu_wrapper(_transform_)

    elif gpu is not None:
        torch.cuda.set_device(gpu)
        netE = netE.cuda(gpu)
    else:
        raise ValueError('distribution_type = multi | single | None')

    if should_diag:
        logger.info(opt)

    optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))

    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)

    # Resume model + optimizer state from a checkpoint if one is given.
    if opt.netE != '':
        ckpt = torch.load(opt.netE)
        netE.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])

    if opt.netE != '':
        start_epoch = torch.load(opt.netE)['epoch'] + 1
    else:
        start_epoch = 0

    for epoch in range(start_epoch, opt.niter):

        if opt.distribution_type == 'multi':
            # Reshuffle shards each epoch.
            train_sampler.set_epoch(epoch)

        lr_scheduler.step(epoch)

        for i, data in enumerate(dataloader):
            x = data['real']     # complete shape
            sv_x = data['raw']   # partial observation (first svpoints entries)

            # Conditioning points followed by the ground-truth remainder.
            sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
            noises_batch = noises_init[data['idx']]

            '''
            train diffusion
            '''

            if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
                sv_x = sv_x.cuda(gpu)
                noises_batch = noises_batch.cuda(gpu)
            elif opt.distribution_type == 'single':
                sv_x = sv_x.cuda()
                noises_batch = noises_batch.cuda()

            loss = netE.get_loss_iter(sv_x, noises_batch).mean()

            optimizer.zero_grad()
            loss.backward()
            # Record norms before (optional) clipping for diagnostics.
            netpNorm, netgradNorm = getGradNorm(netE)
            if opt.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)

            optimizer.step()

            if i % opt.print_freq == 0 and should_diag:
                logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
                            'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
                            .format(
                                epoch, opt.niter, i, len(dataloader),loss.item(),
                                netpNorm, netgradNorm,
                            ))

        if (epoch + 1) % opt.diagIter == 0 and should_diag:

            logger.info('Diagnosis:')

            x_range = [x.min().item(), x.max().item()]
            kl_stats = netE.all_kl(sv_x)
            logger.info(' [{:>3d}/{:>3d}] '
                        'x_range: [{:>10.4f}, {:>10.4f}], '
                        'total_bpd_b: {:>10.4f}, '
                        'terms_bpd: {:>10.4f}, '
                        'prior_bpd_b: {:>10.4f} '
                        'mse_bt: {:>10.4f} '
                        .format(
                            epoch, opt.niter,
                            *x_range,
                            kl_stats['total_bpd_b'].item(),
                            kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
                        ))

        if (epoch + 1) % opt.vizIter == 0 and should_diag:
            logger.info('Generation: eval')

            netE.eval()
            # Identity de-normalization (dataset stats disabled).
            m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0)
            with torch.no_grad():
                # Complete the partial input from the last training batch.
                x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()

                gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
                gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]

                logger.info(' [{:>3d}/{:>3d}] '
                            'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
                            'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
                            .format(
                                epoch, opt.niter,
                                *gen_eval_range, *gen_stats,
                            ))

                export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
                                   (x_gen_eval*s+m).transpose(1, 2).numpy()*3)

                export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
                                   (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)

                export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
                                   (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)

            netE.train()

        if (epoch + 1) % opt.saveIter == 0:

            if should_diag:
                save_dict = {
                    'epoch': epoch,
                    'model_state': netE.state_dict(),
                    'optimizer_state': optimizer.state_dict()
                }
                torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))

            if opt.distribution_type == 'multi':
                # Keep workers in sync: all ranks reload the rank-0 checkpoint.
                dist.barrier()
                map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
                netE.load_state_dict(
                    torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])

    # NOTE(review): called unconditionally — presumably only meaningful (and
    # safe) when a process group was initialized; confirm for single-GPU runs.
    dist.destroy_process_group()
def main():
    """Entry point: set up the experiment directory, pre-sample fixed
    per-example noise, then launch training (spawning one process per GPU
    under 'multi')."""
    opt = parse_args()

    # Experiment directory named after this script file.
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)

    ''' workaround '''
    # Pre-sample one noise tensor per training example so that train() can
    # index it by dataset idx (kept fixed across epochs for t == 0).
    train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes)
    noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints)

    if opt.dist_url == "env://" and opt.world_size == -1:
        opt.world_size = int(os.environ["WORLD_SIZE"])

    if opt.distribution_type == 'multi':
        opt.ngpus_per_node = torch.cuda.device_count()
        # world_size becomes the total process count across all nodes.
        opt.world_size = opt.ngpus_per_node * opt.world_size
        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
    else:
        train(opt.gpu, opt, output_dir, noises_init)
def parse_args(argv=None):
    """Build and run the command-line argument parser.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``.
            Accepting a list keeps the parser unit-testable and is backward
            compatible with the old zero-argument call.

    Returns:
        argparse.Namespace with data, model, optimizer, distributed and
        logging options.
    """
    parser = argparse.ArgumentParser()

    '''data'''
    parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/data_v0',
                        help='PartNet data root directory')
    parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc',
                        help='raw ShapeNet SDF point-cloud root')
    parser.add_argument('--pc_dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
                        help='ShapeNet point-cloud root')
    parser.add_argument('--classes', default='Chair', help='shape category to train on')
    parser.add_argument('--bs', type=int, default=64, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
    # BUG FIX: these were previously untyped, so any command-line override
    # arrived as a string and broke arithmetic such as npoints - svpoints.
    parser.add_argument('--nc', type=int, default=3, help='number of point channels (xyz)')
    parser.add_argument('--npoints', type=int, default=2048, help='total points per shape')
    parser.add_argument('--svpoints', type=int, default=1024, help='number of conditioning (partial) points')

    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    # params
    # NOTE(review): kept untyped for backward compatibility — any non-empty
    # CLI value for --attention is truthy.
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
    parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm (None disables clipping)')
    parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')

    parser.add_argument('--netE', default='', help="path to netE (to continue training)")

    '''distributed'''
    parser.add_argument('--world_size', default=1, type=int,
                        help='Number of distributed nodes.')
    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use. None means using all available GPUs.')

    '''eval'''
    parser.add_argument('--saveIter', default=100, type=int, help='unit: epoch')
    parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
    parser.add_argument('--vizIter', default=50, type=int, help='unit: epoch')
    parser.add_argument('--print_freq', default=50, type=int, help='unit: iter')
    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    opt = parser.parse_args(argv)
    return opt
# Script entry point.
if __name__ == '__main__':
    main()

View file

@ -1,825 +0,0 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from model.pvcnn_completion import PVCNN2Base
import torch.distributed as dist
from datasets.partnet import GANdatasetPartNet
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
'''
some utils
'''
def rotation_matrix(axis, theta):
    """Return the 3x3 matrix for a counterclockwise rotation of ``theta``
    radians about ``axis`` (Euler-Rodrigues formula)."""
    axis = np.asarray(axis, dtype=float)
    axis = axis / np.sqrt(axis @ axis)  # normalize the rotation axis
    half = theta / 2.0
    a = np.cos(half)
    b, c, d = -axis * np.sin(half)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.array([
        [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
        [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
        [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],
    ])
def rotate(vertices, faces):
    """Rotate mesh vertices into the canonical viewing pose.

    vertices: [numpoints, 3]; faces index into vertices. Axes are first
    permuted (y, z, x), then three fixed rotations are applied; faces get the
    same index permutation.
    """
    rot_y = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
    rot_x = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
    rot_z = rotation_matrix([0, 0, 1], np.pi).transpose()
    v = vertices[:, [1, 2, 0]].dot(rot_y).dot(rot_x).dot(rot_z)
    f = faces[:, [1, 2, 0]]
    return v, f
def norm(v, f):
    """Rescale vertex coordinates into [-0.5, 0.5] using the global min/max
    over all coordinates; faces are passed through unchanged."""
    span = v.max() - v.min()
    v = (v - v.min()) / span - 0.5
    return v, f
def getGradNorm(net):
    """Return (parameter L2 norm, gradient L2 norm) over all parameters of
    ``net``. Assumes every parameter already has a populated ``.grad``."""
    param_sq = sum(torch.sum(p ** 2) for p in net.parameters())
    grad_sq = sum(torch.sum(p.grad ** 2) for p in net.parameters())
    return torch.sqrt(param_sq), torch.sqrt(grad_sq)
def weights_init(m):
    """
    xavier initialization

    Conv layers get Xavier-normal weights; BatchNorm layers get unit-normal
    weights and zero bias. Intended for use via ``net.apply(weights_init)``.
    """
    name = m.__class__.__name__
    if 'Conv' in name and m.weight is not None:
        torch.nn.init.xavier_normal_(m.weight)
    elif 'BatchNorm' in name:
        m.weight.data.normal_()
        m.bias.data.fill_(0)
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))."""
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (logvar2 - logvar1 - 1.0 + var_ratio + scaled_sq_diff)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-likelihood of ``x`` under a Gaussian discretized to unit bins.

    Assumes data lies in [0, 1]; the outermost bins absorb the distribution
    tails (x < 0.001 and x > 0.999 branches).
    """
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered = x - means
    inv_std = torch.exp(-log_scales)
    cdf_plus = std_normal.cdf(inv_std * (centered + 0.5))
    cdf_min = std_normal.cdf(inv_std * (centered - 0.5))

    # Clamp probabilities before log to avoid -inf.
    tiny = 1e-12
    log_cdf_plus = torch.log(torch.clamp(cdf_plus, min=tiny))
    log_one_minus_cdf_min = torch.log(torch.clamp(1. - cdf_min, min=tiny))
    log_cdf_delta = torch.log(torch.clamp(cdf_plus - cdf_min, min=tiny))

    log_probs = torch.where(
        x < 0.001,
        log_cdf_plus,
        torch.where(x > 0.999, log_one_minus_cdf_min, log_cdf_delta),
    )
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """Gaussian diffusion process for conditional point-cloud completion.

    Tensors are [B, D, N]; the first ``sv_points`` of the N points are the
    known partial input and are never noised — diffusion runs only on the
    remaining N - sv_points points, conditioned on the fixed prefix.
    """
    def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
        # loss_type: 'mse' (predict noise) or 'kl' (variational bound).
        # model_mean_type: 'eps' — the network outputs the noise estimate.
        # model_var_type: 'fixedsmall' | 'fixedlarge' reverse-process variance.
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.sv_points = sv_points

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """Moments of the forward marginal q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Moments of the reverse step p(x_{t-1} | x_t); ``data`` is the full
        cloud (partial prefix + diffused tail) but only the tail prediction
        is used."""
        model_output = denoise_fn(data, t)[:,:,self.sv_points:]

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
        else:
            # NOTE(review): reports loss_type here though the unmatched value
            # is model_mean_type — likely a copy-paste slip in the message.
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape
        assert model_variance.shape == model_log_variance.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Invert the forward marginal: recover x_0 from x_t and the
        predicted noise eps."""
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """
        Sample from the model

        One reverse step; the partial prefix of ``data`` is re-attached to the
        sampled tail so it stays fixed throughout sampling.
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                              return_pred_xstart=True)
        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)

        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, partial_x, denoise_fn, shape, device,
                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

        ``shape`` is the shape of the tail to generate; the returned tensor
        also contains ``partial_x`` as its prefix.
        """
        assert isinstance(shape, (tuple, list))
        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False)

        assert img_t[:,:,self.sv_points:].shape == shape
        return img_t

    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                                 noise_fn=torch.randn,clip_denoised=True, keep_running=False):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time
        Args:
          repeat_noise_steps (int): Number of denoising timesteps in which the same noise
            is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
        """
        assert isinstance(shape, (tuple, list))

        total_steps = self.num_timesteps if not keep_running else len(self.betas)

        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0,total_steps)):

            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised,
                                  return_pred_xstart=False)
            if t % freq == 0 or t == total_steps-1:
                imgs.append(img_t)

        assert imgs[-1].shape == shape
        return imgs

    '''losses'''

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        # One KL term of the variational bound, in bits per dimension.
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)

        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)

        return (kl, pred_xstart) if return_pred_xstart else kl

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """
        Training loss calculation
        """
        B, D, N = data_start.shape
        assert t.shape == torch.Size([B])

        if noise is None:
            noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)

        data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)

        if self.loss_type == 'mse':
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]

            losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == 'kl':
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
                return_pred_xstart=False)
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == torch.Size([B])
        return losses

    '''debug'''

    def _prior_bpd(self, x_start):
        # KL between q(x_T | x_0) and the standard-normal prior, in bits/dim.
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
                                 mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        """Evaluate the full variational bound: per-timestep KL terms, the
        prior term, and progressive-prediction MSE, averaged over the batch."""
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps

            vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):

                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                    clip_denoised=clip_denoised, return_pred_xstart=True)
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])

            prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
                   total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """Concrete PVCNN2 completion backbone: only the set-abstraction /
    feature-propagation configurations differ from the project base class."""
    # NOTE(review): tuple layout mirrors PVCNN2Base's expectations — presumably
    # ((voxel conv channels, num layers, voxel resolution),
    #  (num centers, radius, num neighbors, MLP channels)) per SA block;
    # confirm against model.pvcnn_completion.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through: all behavior lives in PVCNN2Base.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Completion diffusion model: a PVCNN2 denoiser wrapped by the Gaussian
    diffusion process that trains and samples it."""

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(
            betas, loss_type, model_mean_type, model_var_type, args.svpoints)
        self.model = PVCNN2(
            num_classes=args.nc,
            sv_points=args.svpoints,
            embed_dim=args.embed_dim,
            use_att=args.attention,
            dropout=args.dropout,
            extra_feature_channels=0,
        )

    def prior_kl(self, x0):
        """KL between q(x_T | x_0) and the standard-normal prior (bits/dim)."""
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        """Full variational-bound diagnostics for ``x0``."""
        total, terms, prior, mse = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
        return {
            'total_bpd_b': total,
            'terms_bpd': terms,
            'prior_bpd_b': prior,
            'mse_bt': mse,
        }

    def _denoise(self, data, t):
        """Run the denoiser on a [B, D, N] cloud at integer timesteps ``t``."""
        batch = data.shape[0]
        assert data.dtype == torch.float
        assert t.shape == torch.Size([batch]) and t.dtype == torch.int64
        return self.model(data, t)

    def get_loss_iter(self, data, noises=None):
        """Sample one random timestep per example; return per-example losses."""
        batch = data.shape[0]
        t = torch.randint(0, self.diffusion.num_timesteps, size=(batch,), device=data.device)

        if noises is not None:
            # Re-randomize the pre-allocated noise everywhere except t == 0.
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([batch])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        """Generate completions conditioned on the partial cloud ``partial_x``."""
        return self.diffusion.p_sample_loop(
            partial_x, self._denoise, shape=shape, device=device,
            noise_fn=noise_fn, clip_denoised=clip_denoised, keep_running=keep_running)

    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build the diffusion beta schedule.

    Args:
        schedule_type: 'linear', or 'warm<frac>' (e.g. 'warm0.1') which ramps
            linearly from b_start to b_end over the first <frac> of the steps
            and stays at b_end afterwards.
        b_start: initial beta.
        b_end: final beta.
        time_num: number of diffusion timesteps.

    Returns:
        np.ndarray of shape [time_num] (float64).

    Raises:
        NotImplementedError: for unknown schedule types.
    """
    if schedule_type == 'linear':
        betas = np.linspace(b_start, b_end, time_num)
    elif schedule_type.startswith('warm'):
        # Generalizes the original hard-coded warm0.1 / warm0.2 / warm0.5
        # branches to any warm-up fraction in (0, 1].
        try:
            frac = float(schedule_type[len('warm'):])
        except ValueError:
            raise NotImplementedError(schedule_type)
        if not 0.0 < frac <= 1.0:
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    else:
        raise NotImplementedError(schedule_type)
    return betas
def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category):
    """Build the PartNet training dataset.

    NOTE(review): data_raw_root and pc_dataroot are accepted for signature
    parity with the sibling training scripts but are unused here.
    """
    train_ds = GANdatasetPartNet('train', data_root, category, npoints)
    return train_ds
def get_dataloader(opt, train_dataset, test_dataset=None):
    """Create train/test dataloaders, with DistributedSamplers when
    ``opt.distribution_type == 'multi'``.

    Returns:
        (train_dataloader, test_dataloader, train_sampler, test_sampler);
        the test entries are None when ``test_dataset`` is None.
    """
    if opt.distribution_type == 'multi':
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=opt.world_size,
            rank=opt.rank
        )
        if test_dataset is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=opt.world_size,
                rank=opt.rank
            )
        else:
            test_sampler = None
    else:
        train_sampler = None
        test_sampler = None

    # Shuffle only when no sampler is in charge of ordering.
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs, sampler=train_sampler,
                                                   shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)

    if test_dataset is not None:
        # BUG FIX: previously this loader iterated train_dataset, so the
        # "test" loader silently served training data.
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
    else:
        test_dataloader = None

    return train_dataloader, test_dataloader, train_sampler, test_sampler
def train(gpu, opt, output_dir, noises_init):
    """Train the completion diffusion model in one process.

    Args:
        gpu: GPU index for this process (None to fall back to DataParallel /
            all GPUs, depending on opt.distribution_type).
        opt: parsed argparse options.
        output_dir: experiment directory for logs, samples and checkpoints.
        noises_init: pre-sampled noise, one [nc, npoints - svpoints] tensor
            per training example, indexed by the dataset's 'idx' field.
    """
    set_seed(opt)
    logger = setup_logging(output_dir)

    # Only one process (local rank 0) emits logs / samples / checkpoints.
    if opt.distribution_type == 'multi':
        should_diag = gpu==0
    else:
        should_diag = True
    if should_diag:
        outf_syn, = setup_output_subdirs(output_dir, 'syn')

    if opt.distribution_type == 'multi':
        if opt.dist_url == "env://" and opt.rank == -1:
            opt.rank = int(os.environ["RANK"])

        # Global rank = node rank * GPUs per node + local GPU index.
        base_rank = opt.rank * opt.ngpus_per_node
        opt.rank = base_rank + gpu
        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
                                world_size=opt.world_size, rank=opt.rank)

        # Split the global batch across processes; per-process periodic
        # intervals shrink accordingly.
        opt.bs = int(opt.bs / opt.ngpus_per_node)
        opt.workers = 0

        opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
        opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
        opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)

    ''' data '''
    train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes)
    dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)

    '''
    create networks
    '''
    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.distribution_type == 'multi':  # Multiple processes, single GPU per process
        def _transform_(m):
            return nn.parallel.DistributedDataParallel(
                m, device_ids=[gpu], output_device=gpu)

        torch.cuda.set_device(gpu)
        netE.cuda(gpu)
        netE.multi_gpu_wrapper(_transform_)

    elif opt.distribution_type == 'single':
        def _transform_(m):
            return nn.parallel.DataParallel(m)
        netE = netE.cuda()
        netE.multi_gpu_wrapper(_transform_)

    elif gpu is not None:
        torch.cuda.set_device(gpu)
        netE = netE.cuda(gpu)
    else:
        raise ValueError('distribution_type = multi | single | None')

    if should_diag:
        logger.info(opt)

    optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))

    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)

    # Resume model and optimizer state from a checkpoint if one is given.
    if opt.netE != '':
        ckpt = torch.load(opt.netE)
        netE.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])

    # NOTE(review): second torch.load of the same file just for the epoch
    # number — could reuse `ckpt` above.
    if opt.netE != '':
        start_epoch = torch.load(opt.netE)['epoch'] + 1
    else:
        start_epoch = 0

    for epoch in range(start_epoch, opt.niter):

        if opt.distribution_type == 'multi':
            # Re-seed the sampler so each epoch shuffles differently.
            train_sampler.set_epoch(epoch)

        lr_scheduler.step(epoch)

        for i, data in enumerate(dataloader):
            x = data['real']
            sv_x = data['raw']
            # NOTE(review): 'raw' appears to hold the partial (conditioning)
            # points; the tail of the full cloud 'real' is appended so the
            # first svpoints columns stay fixed during diffusion — confirm
            # against GANdatasetPartNet.
            sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
            noises_batch = noises_init[data['idx']]

            '''
            train diffusion
            '''

            if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
                sv_x = sv_x.cuda(gpu)
                noises_batch = noises_batch.cuda(gpu)
            elif opt.distribution_type == 'single':
                sv_x = sv_x.cuda()
                noises_batch = noises_batch.cuda()

            loss = netE.get_loss_iter(sv_x, noises_batch).mean()

            optimizer.zero_grad()
            loss.backward()
            # Gradient/parameter norms are logged before optional clipping.
            netpNorm, netgradNorm = getGradNorm(netE)
            if opt.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)

            optimizer.step()

            if i % opt.print_freq == 0 and should_diag:
                logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
                            'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
                            .format(
                    epoch, opt.niter, i, len(dataloader),loss.item(),
                    netpNorm, netgradNorm,
                ))

        # Periodic variational-bound diagnostics on the last batch.
        if (epoch + 1) % opt.diagIter == 0 and should_diag:

            logger.info('Diagnosis:')

            x_range = [x.min().item(), x.max().item()]
            kl_stats = netE.all_kl(sv_x)
            logger.info(' [{:>3d}/{:>3d}] '
                        'x_range: [{:>10.4f}, {:>10.4f}], '
                        'total_bpd_b: {:>10.4f}, '
                        'terms_bpd: {:>10.4f}, '
                        'prior_bpd_b: {:>10.4f} '
                        'mse_bt: {:>10.4f} '
                        .format(
                epoch, opt.niter,
                *x_range,
                kl_stats['total_bpd_b'].item(),
                kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
            ))

        # Periodic sampling: export generated / ground-truth / partial clouds.
        if (epoch + 1) % opt.vizIter == 0 and should_diag:
            logger.info('Generation: eval')

            netE.eval()
            # Identity de-normalization (mean 0, scale 1); the commented call
            # shows where per-shape stats would come from.
            m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0)

            with torch.no_grad():

                x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()

                gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
                gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]

                logger.info(' [{:>3d}/{:>3d}] '
                            'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
                            'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
                            .format(
                    epoch, opt.niter,
                    *gen_eval_range, *gen_stats,
                ))

            export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
                               (x_gen_eval*s+m).transpose(1, 2).numpy()*3)

            export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
                               (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)

            export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
                               (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)

            netE.train()

        if (epoch + 1) % opt.saveIter == 0:

            if should_diag:
                save_dict = {
                    'epoch': epoch,
                    'model_state': netE.state_dict(),
                    'optimizer_state': optimizer.state_dict()
                }
                torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))

            if opt.distribution_type == 'multi':
                # All ranks reload rank 0's checkpoint so replicas stay
                # byte-identical after saving.
                dist.barrier()
                map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
                netE.load_state_dict(
                    torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])

    # NOTE(review): called unconditionally — presumably raises when no process
    # group was initialized (single-GPU runs); confirm intended behavior.
    dist.destroy_process_group()
def main():
    """Entry point: build the experiment directory, then launch training."""
    opt = parse_args()

    # Derive the experiment name from this script's filename; archive a copy
    # of the source into the output directory for provenance.
    script_path = __file__
    experiment_name = os.path.splitext(os.path.basename(script_path))[0]
    output_dir = get_output_dir(os.path.dirname(script_path), experiment_name)
    copy_source(script_path, output_dir)

    ''' workaround '''
    # Pre-sample one fixed noise tensor per training example so all workers
    # index into the same initial noise bank.
    train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints, opt.classes)
    noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints - opt.svpoints)

    if opt.dist_url == "env://" and opt.world_size == -1:
        opt.world_size = int(os.environ["WORLD_SIZE"])

    if opt.distribution_type != 'multi':
        train(opt.gpu, opt, output_dir, noises_init)
    else:
        # One process per visible GPU; scale world_size accordingly.
        opt.ngpus_per_node = torch.cuda.device_count()
        opt.world_size = opt.ngpus_per_node * opt.world_size
        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
def parse_args(argv=None):
    """Build and run the command-line argument parser.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``.
            Accepting a list keeps the parser unit-testable and is backward
            compatible with the old zero-argument call.

    Returns:
        argparse.Namespace with data, model, optimizer, distributed and
        logging options.
    """
    parser = argparse.ArgumentParser()

    '''data'''
    parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/',
                        help='PartNet data root directory')
    parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc',
                        help='raw ShapeNet SDF point-cloud root')
    parser.add_argument('--pc_dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
                        help='ShapeNet point-cloud root')
    parser.add_argument('--classes', default='Table', help='shape category to train on')
    parser.add_argument('--bs', type=int, default=64, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
    # BUG FIX: these were previously untyped, so any command-line override
    # arrived as a string and broke arithmetic such as npoints - svpoints.
    parser.add_argument('--nc', type=int, default=3, help='number of point channels (xyz)')
    parser.add_argument('--npoints', type=int, default=2048, help='total points per shape')
    parser.add_argument('--svpoints', type=int, default=1024, help='number of conditioning (partial) points')

    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    # params
    # NOTE(review): kept untyped for backward compatibility — any non-empty
    # CLI value for --attention is truthy.
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
    parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm (None disables clipping)')
    parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')

    parser.add_argument('--netE', default='', help="path to netE (to continue training)")

    '''distributed'''
    parser.add_argument('--world_size', default=1, type=int,
                        help='Number of distributed nodes.')
    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use. None means using all available GPUs.')

    '''eval'''
    parser.add_argument('--saveIter', default=100, type=int, help='unit: epoch')
    parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
    parser.add_argument('--vizIter', default=50, type=int, help='unit: epoch')
    parser.add_argument('--print_freq', default=50, type=int, help='unit: iter')
    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    opt = parser.parse_args(argv)
    return opt
# Script entry point.
if __name__ == '__main__':
    main()

View file

@ -1,822 +0,0 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from model.pvcnn_completion import PVCNN2Base
import torch.distributed as dist
from datasets.partnet import GANdatasetPartNet
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
'''
some utils
'''
def rotation_matrix(axis, theta):
    """
    Return the 3x3 matrix for a counterclockwise rotation about `axis`
    by `theta` radians, built from the equivalent unit quaternion.
    """
    u = np.asarray(axis)
    u = u / np.sqrt(np.dot(u, u))  # normalize the rotation axis
    qw = np.cos(theta / 2.0)
    qx, qy, qz = -u * np.sin(theta / 2.0)
    ww, xx, yy, zz = qw * qw, qx * qx, qy * qy, qz * qz
    xy, wz, wy, wx, xz, yz = qx * qy, qw * qz, qw * qy, qw * qx, qx * qz, qy * qz
    return np.array([
        [ww + xx - yy - zz, 2 * (xy + wz),     2 * (xz - wy)],
        [2 * (xy - wz),     ww + yy - xx - zz, 2 * (yz + wx)],
        [2 * (xz + wy),     2 * (yz - wx),     ww + zz - xx - yy],
    ])
def rotate(vertices, faces):
    '''
    vertices: [numpoints, 3]
    Reorder axes and apply three fixed rotations to bring the mesh into the
    canonical viewing frame; faces get the same axis reordering.
    '''
    about_y = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
    about_x = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
    about_z = rotation_matrix([0, 0, 1], np.pi).transpose()
    v = vertices[:, [1, 2, 0]].dot(about_y).dot(about_x).dot(about_z)
    f = faces[:, [1, 2, 0]]
    return v, f
def norm(v, f):
    """Rescale vertex coordinates to [-0.5, 0.5] by the global min/max; faces pass through."""
    lo, hi = v.min(), v.max()
    v = (v - lo) / (hi - lo) - 0.5
    return v, f
def getGradNorm(net):
    """
    Return (parameter L2 norm, gradient L2 norm) aggregated over all
    parameters of `net`, as 0-dim tensors (used for training diagnostics).

    Fix: parameters whose `.grad` is None (frozen, or no backward pass yet)
    previously raised a TypeError inside the sum; they now simply contribute
    nothing to the gradient norm.
    """
    param_sq = sum(torch.sum(p ** 2) for p in net.parameters())
    grad_sq = sum(torch.sum(p.grad ** 2) for p in net.parameters() if p.grad is not None)
    if not torch.is_tensor(grad_sq) and torch.is_tensor(param_sq):
        # every grad was None: report a zero gradient norm instead of crashing
        grad_sq = torch.zeros_like(param_sq)
    pNorm = torch.sqrt(param_sq)
    gradNorm = torch.sqrt(grad_sq)
    return pNorm, gradNorm
def weights_init(m):
    """
    xavier initialization

    Conv layers get Xavier-normal weights; batch-norm layers get unit-normal
    weights and zero bias. Intended for ``net.apply(weights_init)``.
    """
    name = m.__class__.__name__
    is_conv = 'Conv' in name
    is_batchnorm = 'BatchNorm' in name
    if is_conv and m.weight is not None:
        torch.nn.init.xavier_normal_(m.weight)
    elif is_batchnorm:
        m.weight.data.normal_()
        m.bias.data.fill_(0)
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL( N(mean1, e^logvar1) || N(mean2, e^logvar2) ) between two
    diagonal Gaussians given as mean / log-variance pairs.
    """
    delta_logvar = logvar2 - logvar1
    mahalanobis = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (-1.0 + delta_logvar + torch.exp(-delta_logvar) + mahalanobis)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    # Assumes data is integers [0, 1]
    # Log-likelihood of x under a Gaussian discretized into bins of width 1,
    # with the extreme bins extended to -inf / +inf.
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered = x - means
    inv_std = torch.exp(-log_scales)
    cdf_upper = std_normal.cdf(inv_std * (centered + 0.5))
    cdf_lower = std_normal.cdf(inv_std * (centered - .5))

    eps = 1e-12  # floor probabilities before log to avoid -inf
    log_cdf_upper = torch.log(cdf_upper.clamp(min=eps))
    log_sf_lower = torch.log((1. - cdf_lower).clamp(min=eps))
    log_bin_prob = torch.log((cdf_upper - cdf_lower).clamp(min=eps))

    # lowest bin uses the left tail, highest bin the right tail
    log_probs = torch.where(
        x < 0.001, log_cdf_upper,
        torch.where(x > 0.999, log_sf_lower, log_bin_prob))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """
    DDPM forward/reverse process specialized for point-cloud completion.

    Point clouds are [B, D, N]; along the last (point) axis, the first
    `sv_points` points are the observed partial shape and are never noised —
    only the remaining points are diffused, denoised and scored.
    All schedule coefficients are precomputed in float64 and stored as
    float32 CPU tensors, moved to the data's device on use.
    """
    def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
        # loss_type: 'mse' (epsilon regression) or 'kl' (variational bound)
        # model_mean_type: 'eps' is the only branch implemented below
        # model_var_type: 'fixedsmall' | 'fixedlarge' (variance is not learned)
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.sv_points = sv_points

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        # Moments of q(x_t | x_0): mean scales x_start, variance comes from the schedule.
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        # `data` is [conditioning | noisy] along the point axis; the network sees
        # both, but only its prediction for the noisy part is used.
        # NOTE(review): `clip_denoised` is accepted but no clipping is applied here.
        model_output = denoise_fn(data, t)[:,:,self.sv_points:]

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]

            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            # recover x_0 from the predicted noise, then use the exact posterior mean
            x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
        else:
            # NOTE(review): message uses loss_type; model_mean_type was probably intended
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape
        assert model_variance.shape == model_log_variance.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        # Invert q_sample: x_0 = x_t / sqrt(acp_t) - eps * sqrt(1/acp_t - 1)
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """
        Sample from the model
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                              return_pred_xstart=True)
        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)

        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        # re-attach the (never noised) conditioning points ahead of the sampled part
        sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, partial_x, denoise_fn, shape, device,
                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

        `shape` is the shape of the *missing* part only; `partial_x` is
        concatenated in front and carried unchanged through the chain.
        """
        assert isinstance(shape, (tuple, list))
        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False)

        assert img_t[:,:,self.sv_points:].shape == shape
        return img_t

    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                                 noise_fn=torch.randn,clip_denoised=True, keep_running=False):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time
        Args:
          repeat_noise_steps (int): Number of denoising timesteps in which the same noise
            is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
        """
        assert isinstance(shape, (tuple, list))

        total_steps = self.num_timesteps if not keep_running else len(self.betas)

        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0,total_steps)):

            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised,
                                  return_pred_xstart=False)
            if t % freq == 0 or t == total_steps-1:  # snapshot every `freq` steps
                imgs.append(img_t)

        assert imgs[-1].shape == shape
        return imgs

    '''losses'''

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        # One KL term of the variational bound, in bits per dimension,
        # computed only over the generated (non-conditioning) points.
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)

        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)  # nats -> bits

        return (kl, pred_xstart) if return_pred_xstart else kl

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """
        Training loss calculation
        """
        B, D, N = data_start.shape
        assert t.shape == torch.Size([B])

        if noise is None:
            # noise only the part to be generated
            noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)

        data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)

        if self.loss_type == 'mse':
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]

            losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == 'kl':
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
                return_pred_xstart=False)
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == torch.Size([B])
        return losses

    '''debug'''

    def _prior_bpd(self, x_start):
        # KL( q(x_T | x_0) || N(0, I) ) in bits/dim — the irreducible prior term.
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
                                 mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        # Full variational bound: per-timestep KL terms + prior term, plus the
        # per-timestep MSE of the x_0 prediction, all averaged over the batch.
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps

            vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):

                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                    clip_denoised=clip_denoised, return_pred_xstart=True)
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                # NOTE(review): `vals_bt_` is compared twice; the second operand
                # was presumably meant to be mse_bt_.
                assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])

            prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
                   total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """
    Point-Voxel CNN denoiser configuration for shape completion.
    Only the stage configuration lives here; all behavior comes from PVCNN2Base.
    """
    # Set-abstraction (encoder) stages, one tuple per downsampling level.
    # NOTE(review): exact tuple field semantics are defined by PVCNN2Base
    # (model/pvcnn_completion.py) — confirm there before editing.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    # Feature-propagation (decoder) stages, mirroring sa_blocks.
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pass everything through; this subclass only pins sa_blocks/fp_blocks.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """
    Couples the PVCNN2 denoising network with the GaussianDiffusion process:
    training losses, likelihood diagnostics, and conditional sampling.
    """
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)

        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # Prior KL term of the bound (bits/dim), see GaussianDiffusion._prior_bpd.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        # Full variational-bound diagnostics over all timesteps.
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt':mse_bt
        }

    def _denoise(self, data, t):
        # data: [B, D, N] float point clouds; t: [B] int64 timesteps.
        B, D,N= data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        # Per-sample training losses at uniformly drawn timesteps.
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # keep the fixed noise only where t == 0; redraw it elsewhere
            noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        # Complete `partial_x` by sampling the missing part (of `shape`)
        # through the reverse diffusion chain.
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def train(self):
        # Delegates mode switching to the wrapped network only.
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Wrap the inner network with DataParallel/DistributedDataParallel.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """
    Build the diffusion beta schedule as a float64 numpy array of length
    `time_num`.

    schedule_type:
      'linear'  — linear ramp from b_start to b_end.
      'warmF'   — constant b_end, with the first F fraction of steps ramping
                  linearly from b_start to b_end (generalizes the previous
                  hard-coded 'warm0.1' / 'warm0.2' / 'warm0.5' cases; those
                  strings produce byte-identical schedules).

    Raises NotImplementedError for any other schedule name.
    """
    if schedule_type == 'linear':
        betas = np.linspace(b_start, b_end, time_num)
    elif schedule_type.startswith('warm'):
        try:
            warm_frac = float(schedule_type[len('warm'):])
        except ValueError:
            # preserve the original contract for unknown schedule names
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * warm_frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    else:
        raise NotImplementedError(schedule_type)
    return betas
def get_dataset(data_root, npoints, category):
    """Build the PartNet training split used for completion training."""
    return GANdatasetPartNet('train', data_root, category, npoints)
def get_dataloader(opt, train_dataset, test_dataset=None):
    """
    Build dataloaders (and DistributedSamplers when opt.distribution_type ==
    'multi') for training and optional evaluation.

    Returns (train_dataloader, test_dataloader, train_sampler, test_sampler);
    the test entries are None when `test_dataset` is None, and the samplers
    are None outside multi-process mode.

    Fix: the test dataloader previously wrapped `train_dataset`, so evaluation
    silently iterated training data; it now wraps `test_dataset`.
    """
    if opt.distribution_type == 'multi':
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=opt.world_size,
            rank=opt.rank
        )
        if test_dataset is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=opt.world_size,
                rank=opt.rank
            )
        else:
            test_sampler = None
    else:
        train_sampler = None
        test_sampler = None

    # shuffle only when no sampler drives the ordering; drop ragged last batch
    # during training so batch statistics stay consistent.
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs, sampler=train_sampler,
                                                   shuffle=train_sampler is None, num_workers=int(opt.workers),
                                                   drop_last=True)

    if test_dataset is not None:
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
    else:
        test_dataloader = None

    return train_dataloader, test_dataloader, train_sampler, test_sampler
def train(gpu, opt, output_dir, noises_init):
    """
    Training worker for one process.

    gpu:         GPU index for this worker (None -> CPU/all-GPU single mode).
    opt:         parsed options namespace (mutated in multi-process mode).
    output_dir:  directory for logs, checkpoints and sample exports.
    noises_init: fixed per-sample noise bank [len(dataset), nc, npoints-svpoints],
                 indexed by each batch's 'idx' field.
    """
    set_seed(opt)
    logger = setup_logging(output_dir)
    # Only rank-0 (or the single process) logs, saves and visualizes.
    if opt.distribution_type == 'multi':
        should_diag = gpu==0
    else:
        should_diag = True
    if should_diag:
        outf_syn, = setup_output_subdirs(output_dir, 'syn')

    if opt.distribution_type == 'multi':
        if opt.dist_url == "env://" and opt.rank == -1:
            opt.rank = int(os.environ["RANK"])

        # global rank = node rank * gpus-per-node + local gpu index
        base_rank = opt.rank * opt.ngpus_per_node
        opt.rank = base_rank + gpu
        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
                                world_size=opt.world_size, rank=opt.rank)

        # split batch and schedule intervals across processes
        opt.bs = int(opt.bs / opt.ngpus_per_node)
        opt.workers = 0

        opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
        opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
        opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)

    ''' data '''
    train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
    dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)

    '''
    create networks
    '''
    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.distribution_type == 'multi':  # Multiple processes, single GPU per process
        def _transform_(m):
            return nn.parallel.DistributedDataParallel(
                m, device_ids=[gpu], output_device=gpu)

        torch.cuda.set_device(gpu)
        netE.cuda(gpu)
        netE.multi_gpu_wrapper(_transform_)

    elif opt.distribution_type == 'single':
        def _transform_(m):
            return nn.parallel.DataParallel(m)
        netE = netE.cuda()
        netE.multi_gpu_wrapper(_transform_)

    elif gpu is not None:
        torch.cuda.set_device(gpu)
        netE = netE.cuda(gpu)
    else:
        raise ValueError('distribution_type = multi | single | None')

    if should_diag:
        logger.info(opt)

    optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))

    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)

    # resume from checkpoint if provided
    if opt.netE != '':
        ckpt = torch.load(opt.netE)
        netE.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])

    # NOTE(review): second load of the same checkpoint just to read the epoch;
    # could reuse `ckpt` above.
    if opt.netE != '':
        start_epoch = torch.load(opt.netE)['epoch'] + 1
    else:
        start_epoch = 0

    for epoch in range(start_epoch, opt.niter):

        if opt.distribution_type == 'multi':
            # reshuffle shards each epoch
            train_sampler.set_epoch(epoch)

        lr_scheduler.step(epoch)

        for i, data in enumerate(dataloader):
            x = data['real']        # complete shape [B, D, N]
            sv_x = data['raw']      # observed partial shape

            # conditioning points first, then the part to be generated
            sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
            noises_batch = noises_init[data['idx']]

            '''
            train diffusion
            '''

            if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
                sv_x = sv_x.cuda(gpu)
                noises_batch = noises_batch.cuda(gpu)
            elif opt.distribution_type == 'single':
                sv_x = sv_x.cuda()
                noises_batch = noises_batch.cuda()

            loss = netE.get_loss_iter(sv_x, noises_batch).mean()

            optimizer.zero_grad()
            loss.backward()
            # norms recorded before clipping, for diagnostics
            netpNorm, netgradNorm = getGradNorm(netE)
            if opt.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)

            optimizer.step()

            if i % opt.print_freq == 0 and should_diag:
                logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
                            'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
                            .format(
                                epoch, opt.niter, i, len(dataloader),loss.item(),
                                netpNorm, netgradNorm,
                            ))

        # periodic likelihood diagnostics (uses the last batch of the epoch)
        if (epoch + 1) % opt.diagIter == 0 and should_diag:

            logger.info('Diagnosis:')

            x_range = [x.min().item(), x.max().item()]
            kl_stats = netE.all_kl(sv_x)
            logger.info(' [{:>3d}/{:>3d}] '
                        'x_range: [{:>10.4f}, {:>10.4f}], '
                        'total_bpd_b: {:>10.4f}, '
                        'terms_bpd: {:>10.4f}, '
                        'prior_bpd_b: {:>10.4f} '
                        'mse_bt: {:>10.4f} '
                        .format(
                            epoch, opt.niter,
                            *x_range,
                            kl_stats['total_bpd_b'].item(),
                            kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
                        ))

        # periodic visualization: complete the last batch's partial shapes
        if (epoch + 1) % opt.vizIter == 0 and should_diag:
            logger.info('Generation: eval')

            netE.eval()
            # identity de-normalization placeholders (mean 0, scale 1)
            m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0)

            with torch.no_grad():

                x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()

                gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
                gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]

                logger.info(' [{:>3d}/{:>3d}] '
                            'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
                            'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
                            .format(
                                epoch, opt.niter,
                                *gen_eval_range, *gen_stats,
                            ))

            export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
                               (x_gen_eval*s+m).transpose(1, 2).numpy()*3)

            export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
                               (sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)

            export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
                               (sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)

            netE.train()

        # periodic checkpointing; other ranks reload rank-0's weights to stay in sync
        if (epoch + 1) % opt.saveIter == 0:

            if should_diag:
                save_dict = {
                    'epoch': epoch,
                    'model_state': netE.state_dict(),
                    'optimizer_state': optimizer.state_dict()
                }

                torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))

            if opt.distribution_type == 'multi':
                dist.barrier()
                map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
                netE.load_state_dict(
                    torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])

    dist.destroy_process_group()
def main():
    """
    Entry point: parse options, create the experiment output directory,
    pre-draw the fixed per-sample noise bank, then run training either in
    this process or via one spawned process per GPU.
    """
    opt = parse_args()

    # experiment id / output dir derived from this script's filename
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)

    ''' workaround '''
    # fixed noise drawn once for the whole dataset so every worker shares it
    train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
    noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints)

    if opt.dist_url == "env://" and opt.world_size == -1:
        opt.world_size = int(os.environ["WORLD_SIZE"])

    if opt.distribution_type == 'multi':
        opt.ngpus_per_node = torch.cuda.device_count()
        opt.world_size = opt.ngpus_per_node * opt.world_size
        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
    else:
        train(opt.gpu, opt, output_dir, noises_init)
def parse_args():
    """
    Build and parse the command-line options for completion training.

    Fixes: numeric options now carry explicit `type=` converters so values
    supplied on the command line arrive as int/float rather than str
    (defaults are unchanged); two copy-pasted help strings corrected
    (--data_root, --grad_clip).
    """
    parser = argparse.ArgumentParser()
    # data
    parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet/', help='path to the PartNet data root')
    parser.add_argument('--classes', default='Table')

    parser.add_argument('--bs', type=int, default=64, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    parser.add_argument('--nc', type=int, default=3, help='number of point channels (xyz)')
    parser.add_argument('--npoints', type=int, default=2048, help='total points per shape')
    parser.add_argument('--svpoints', type=int, default=1024, help='number of observed (conditioning) points')
    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000, help='number of diffusion timesteps')

    #params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
    parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm (None disables clipping)')
    parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')

    parser.add_argument('--netE', default='', help="path to netE (to continue training)")

    '''distributed'''
    parser.add_argument('--world_size', default=1, type=int,
                        help='Number of distributed nodes.')
    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use. None means using all available GPUs.')

    '''eval'''
    parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch')
    parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
    parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch')
    parser.add_argument('--print_freq', default=50, type=int,help='unit: iter')

    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    opt = parser.parse_args()

    return opt
# Script entry point: delegate to main() for option parsing and training launch.
if __name__ == '__main__':
    main()

View file

@ -1,660 +0,0 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL( N(mean1, e^logvar1) || N(mean2, e^logvar2) ) for diagonal
    Gaussians parameterized by mean and log-variance.
    """
    delta_logvar = logvar2 - logvar1
    mahalanobis = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (-1.0 + delta_logvar + torch.exp(-delta_logvar) + mahalanobis)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    # Assumes data is integers [0, 1]
    # Log-likelihood under a Gaussian discretized into unit-width bins, with
    # the extreme bins widened to cover the tails.
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered = x - means
    inv_std = torch.exp(-log_scales)
    cdf_upper = std_normal.cdf(inv_std * (centered + 0.5))
    cdf_lower = std_normal.cdf(inv_std * (centered - .5))

    eps = 1e-12  # probability floor before taking logs
    log_cdf_upper = torch.log(cdf_upper.clamp(min=eps))
    log_sf_lower = torch.log((1. - cdf_lower).clamp(min=eps))
    log_bin_prob = torch.log((cdf_upper - cdf_lower).clamp(min=eps))

    log_probs = torch.where(
        x < 0.001, log_cdf_upper,
        torch.where(x > 0.999, log_sf_lower, log_bin_prob))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.loss_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                                 noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples, returning intermediate clouds every `freq` steps —
        useful for visualizing how the denoised shape evolves over time.

        keep_running: True runs len(self.betas) steps instead of num_timesteps.

        NOTE(review): unlike p_sample_loop, no partial_x is prepended here, yet
        p_sample still splices data[:, :, :sv_points] back into each sample, so
        the first sv_points columns of the initial noise act as the "partial"
        conditioning — confirm this is intended for visualization-only use.
        """
        assert isinstance(shape, (tuple, list))
        total_steps = self.num_timesteps if not keep_running else len(self.betas)

        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0, total_steps)):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised,
                                  return_pred_xstart=False)
            # Keep a snapshot every `freq` steps (and the first denoised step).
            if t % freq == 0 or t == total_steps - 1:
                imgs.append(img_t)

        assert imgs[-1].shape == shape
        return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
    def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb, denoise_fn, noise_fn=torch.randn):
        """
        Interpolate between two completions by noising both free-point sets
        t-1 steps forward, blending linearly with weight `lamb`, and denoising
        the blend back down conditioned on a mix of both partial inputs.
        """
        assert t >= 1
        t_vec = torch.empty(x0_part.shape[0], dtype=torch.int64, device=x0_part.device).fill_(t - 1)
        # Forward-diffuse both free-point sets to the same noise level.
        encoding0 = self.q_sample(x0_part, t_vec)
        encoding1 = self.q_sample(x1_part, t_vec)
        # Linear blend in the noised (latent) space.
        enc = encoding0 * (1 - lamb) + (lamb) * encoding1
        # Conditioning: first (1-lamb) fraction of x0's partial points, remainder
        # from x1's, keeping exactly sv_points columns in total.
        img_t = torch.cat([torch.cat([x0_sv[:, :, :int(self.sv_points * (1 - lamb))], x1_sv[:, :, :(self.sv_points - int(self.sv_points * (1 - lamb)))]], dim=-1), enc], dim=-1)
        # Reverse-diffuse from step t back to 0.
        for k in reversed(range(0, t)):
            t_ = torch.empty(img_t.shape[0], dtype=torch.int64, device=img_t.device).fill_(k)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=False, return_pred_xstart=False).detach()
        return img_t
class PVCNN2(PVCNN2Base):
    """
    Point-Voxel CNN backbone configuration used for shape completion.

    NOTE(review): the tuple fields of sa_blocks / fp_blocks are interpreted by
    PVCNN2Base in model/pvcnn_completion.py (set-abstraction and
    feature-propagation block specs); confirm field meanings there before
    editing these numbers.
    """
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through to the base class; all architecture choices live
        # in sa_blocks / fp_blocks above.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """
    Completion model: a PVCNN2 noise-prediction backbone driven by a
    GaussianDiffusion process. The first args.svpoints entries along the last
    axis of every cloud are the fixed partial input.
    """

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)

        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # KL(q(x_T | x_0) || N(0, I)) in bits/dim.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        # Full variational-bound breakdown over all timesteps.
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt': mse_bt
        }

    def _denoise(self, data, t):
        # data: (B, D, N) float cloud; t: (B,) int64 timesteps.
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        # Sample one random timestep per example and compute the training loss.
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # Re-draw fresh noise for every example with t != 0.
            # NOTE(review): rows of `noises` where t == 0 keep the
            # caller-provided values — confirm this is intended.
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        # Complete partial_x by sampling the remaining `shape`-sized free points.
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb):
        # Latent-space interpolation between two completions (see GaussianDiffusion.interpolate).
        return self.diffusion.interpolate(x0_part, x1_part, x0_sv, x1_sv, t, lamb, self._denoise)

    def train(self):
        # NOTE(review): shadows nn.Module.train(mode=True) without the `mode` arg.
        self.model.train()

    def eval(self):
        # NOTE(review): shadows nn.Module.eval().
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Wrap only the backbone (e.g. with DataParallel), not the diffusion helper.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """
    Build the per-step noise-variance schedule.

    schedule_type: 'linear' for a straight ramp from b_start to b_end, or
        'warm0.1' / 'warm0.2' / 'warm0.5' for a constant b_end schedule whose
        first fraction (10% / 20% / 50%) ramps linearly from b_start to b_end.
    Returns a float64 np.ndarray of length time_num.
    Raises NotImplementedError for unknown schedule names.
    """
    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num)

    # The warm* schedules only differ by the warm-up fraction; dedupe them.
    warm_fraction = {'warm0.1': 0.1, 'warm0.2': 0.2, 'warm0.5': 0.5}.get(schedule_type)
    if warm_fraction is None:
        raise NotImplementedError(schedule_type)

    betas = b_end * np.ones(time_num, dtype=np.float64)
    warmup_time = int(time_num * warm_fraction)
    betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    return betas
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints, category):
    """
    Build the single-view-rendering validation dataset.

    The ShapeNet training split is constructed only to obtain its
    normalization statistics (all_points_mean / all_points_std), which are
    then applied to the returned validation set.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
                                        categories=category, split='train',
                                        tr_sample_size=npoints,
                                        te_sample_size=npoints,
                                        scale=1.,
                                        normalize_per_shape=False,
                                        normalize_std_per_axis=False,
                                        random_subsample=True)
    te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
                                            cache=os.path.join(mesh_root, '../cache'), split='val',
                                            categories=category,
                                            radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
                                            npoints=npoints, sv_samples=200,
                                            all_points_mean=tr_dataset.all_points_mean,
                                            all_points_std=tr_dataset.all_points_std,
                                            )
    return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints, category, get_image=True):
    """
    Build the multi-view-rendering validation dataset.

    As in get_svr_dataset, the training split is built solely to provide the
    normalization statistics for the returned validation set.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
                                        categories=category, split='train',
                                        tr_sample_size=npoints,
                                        te_sample_size=npoints,
                                        scale=1.,
                                        normalize_per_shape=False,
                                        normalize_std_per_axis=False,
                                        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
                                           cache=os.path.join(pc_dataroot, '../cache'), split='val',
                                           categories=category,
                                           npoints=npoints, sv_samples=200,
                                           all_points_mean=tr_dataset.all_points_mean,
                                           all_points_std=tr_dataset.all_points_std, get_image=get_image,
                                           )
    return te_dataset
def generate_multimodal(opt, netE, save_dir, logger):
    """
    For one hand-picked validation batch, sample 20 completion modes from two
    depth views each and render latent-space interpolations between the two
    views' completions.

    Writes Mitsuba XML scenes, PLY point clouds and the input depth images
    under save_dir. `logger` is currently unused.
    """
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        # Only batch index 3 is processed (hand-picked example).
        if i != 3:
            continue
        gt_all = data['test_points']
        x_all = data['sv_points']
        img_all = data['image']

        for v in range(20):  # 20 sampled completion modes
            recons = []
            svs = []
            for p in [0, 1]:  # the two depth views
                x = x_all[:, p].transpose(1, 2).contiguous()
                img = img_all[:, p]

                m, s = data['mean'].float(), data['std'].float()

                # Complete the free points conditioned on the first svpoints columns.
                recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                         clip_denoised=False).detach().cpu()

                recon = recon.transpose(1, 2).contiguous()
                x = x.transpose(1, 2).contiguous()

                recons.append(recon)
                svs.append(x[:, :opt.svpoints, :])

                # Save GT / reconstruction / input (un-normalized) per example.
                for l, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
                    # im = np.fliplr(np.flipud(d[-1]))
                    write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, l), 'mode_%03d' % v, 'depth_%d' % p),
                                       (torch.stack(d[:-1], dim=0) * s[0] + m[0]).numpy(), cat='chair')

                    export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v, 'depth_%d' % p),
                                       (torch.stack(d[:-1], dim=0) * s[0] + m[0]).numpy())

                    plt.imsave(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v, 'depth_%d.png' % p),
                               d[-1].permute(1, 2, 0), cmap='gray')

            # Interpolate between the two views' completions in noised latent space.
            x0_part = recons[0].transpose(1, 2).contiguous()[:, :, opt.svpoints:].cuda()
            x1_part = recons[1].transpose(1, 2).contiguous()[:, :, opt.svpoints:].cuda()
            x0_sv = svs[0].transpose(1, 2).cuda()
            x1_sv = svs[1].transpose(1, 2).cuda()

            interres = []
            for lamb in np.linspace(0.1, 0.9, 5):
                res = netE.interpolate(x0_part, x1_part, x0_sv, x1_sv, 1000, lamb)
                res = torch.cat([x0_sv, x1_sv, res[:, :, opt.svpoints:]], dim=-1).detach().cpu().transpose(1, 2).contiguous()
                interres.append(res)

            for l, d in enumerate(torch.stack(interres, dim=1)):
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, l), 'mode_%03d' % v),
                                   (d * s[0] + m[0]).numpy(), cat='chair')
                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v),
                                   (d * s[0] + m[0]).numpy())
def main(opt):
    """
    Set up the output directories and the model, then run multimodal
    generation for every checkpoint found in opt.ckpt_dir (newest first).
    """
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    outf_syn, = setup_output_subdirs(output_dir, 'syn')

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        return nn.parallel.DataParallel(m)

    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)

    netE.eval()

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    with torch.no_grad():
        # Iterate checkpoints newest-first, keyed on the trailing epoch number.
        # NOTE(review): str.strip('.pth') strips any of the chars {., p, t, h}
        # from both ends, not the literal suffix — it happens to work for
        # 'epoch_NNN.pth' names but os.path.splitext would be safer.
        for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]))):
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)

            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])

            generate_multimodal(opt, netE, outf_syn, logger)
def parse_args():
    """
    Build and parse the CLI options for multimodal completion generation.

    Returns an argparse.Namespace. Numeric options carry explicit type=
    converters so command-line values arrive as int/float (previously they
    were left as strings, breaking downstream arithmetic). Also sets
    opt.cuda from torch.cuda availability.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                        help='input batch size')
    parser.add_argument('--classes', default=['chair'])

    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    parser.add_argument('--eval_recon_mvr', default=False)
    parser.add_argument('--generate_multimodal', default=False)
    parser.add_argument('--eval_saved', default=False)
    parser.add_argument('--eval_redwood', default=True)

    # BUGFIX: the numeric options below previously had no type=, so values
    # passed on the command line were kept as strings.
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)
    parser.add_argument('--svpoints', type=int, default=200)

    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    # params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)

    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)")

    '''eval'''
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')

    parser.add_argument('--manualSeed', default=None, type=int, help='random seed')

    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')

    opt = parser.parse_args()

    # Record CUDA availability on the options namespace.
    opt.cuda = torch.cuda.is_available()

    return opt
# Script entry point: parse CLI options, seed RNGs, then run generation.
if __name__ == '__main__':
    opt = parse_args()
    set_seed(opt)

    main(opt)

View file

@ -1,706 +0,0 @@
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer2 import write_to_xml_batch, write_to_xml
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) between
    normal distributions parameterized by mean and log-variance.
    """
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (logvar2 - logvar1 - 1.0 + var_ratio + scaled_sq_diff)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of x under Gaussians discretized into unit-width bins.

    Each entry's probability is the Gaussian mass on (x - 0.5, x + 0.5),
    except near the data range edges (x < 0.001 / x > 0.999) where the open
    tail mass is used instead. Logs are clamped at 1e-12 for stability.
    """
    # Assumes data is integers [0, 1]
    assert x.shape == means.shape == log_scales.shape
    # Standard normal used only to evaluate CDFs.
    px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 0.5)
    cdf_plus = px0.cdf(plus_in)
    min_in = inv_stdv * (centered_x - .5)
    cdf_min = px0.cdf(min_in)
    # Clamp before taking logs so saturated CDF values don't produce -inf.
    log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
    log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min) * 1e-12))
    cdf_delta = cdf_plus - cdf_min

    log_probs = torch.where(
        x < 0.001, log_cdf_plus,
        torch.where(x > 0.999, log_one_minus_cdf_min,
                    torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.loss_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory2(self, partial_x, denoise_fn, shape, device, num_save,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
repeat_noise_steps (int): Number of denoising timesteps in which the same noise
is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
scale = np.exp(np.log(1/total_steps)/num_save)
save_step = total_steps
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
imgs = [img_t.detach().cpu()]
for t in reversed(range(0,total_steps)):
if (t+1) == save_step and t > 0 and len(imgs)<num_save:
imgs.append(img_t.detach().cpu())
save_step = int(save_step * scale)
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
imgs.append(img_t.detach().cpu())
assert imgs[-1][:,:,self.sv_points:].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
    """Evaluate the full variational bound, one timestep at a time.

    Args:
        denoise_fn: callable (data, t) -> model output.
        x_start: (B, C, N) clean clouds; first ``sv_points`` are conditioned.
        clip_denoised: forwarded to p_mean_variance.

    Returns:
        (mean total bpd, mean per-timestep VLB terms, mean prior bpd,
         mean per-timestep x0-prediction MSE), all scalars.
    """
    with torch.no_grad():
        B, T = x_start.shape[0], self.num_timesteps

        vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
        for t in reversed(range(T)):
            t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
            # Calculate VLB term at the current timestep: noise the free
            # points, keep the conditioned prefix clean.
            data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
            new_vals_b, pred_xstart = self._vb_terms_bpd(
                denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                clip_denoised=clip_denoised, return_pred_xstart=True)
            # MSE for progressive prediction loss
            assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
            new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
            assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
            # Insert the calculated term into the tensor of all terms.
            # FIX: compare against an integer arange (t_b is int64); the
            # original compared to a float arange, relying on implicit casting.
            mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :]
            vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
            mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
            # FIX: the original assert compared vals_bt_ to itself twice and
            # never shape-checked mse_bt_.
            assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])

        prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points:])
        total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
        assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
            total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
        return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """Point-Voxel CNN backbone configuration used as the denoising network.

    The block tuples are consumed by PVCNN2Base (defined elsewhere in the
    project). NOTE(review): the field meanings below are inferred from the
    values, not visible here — confirm against model/pvcnn_completion.py:
      sa_blocks: per downsampling stage,
        ((out_channels, num_blocks, voxel_resolution) or None,
         (num_centers, radius, num_neighbors, mlp_channels))
      fp_blocks: per upsampling stage,
        (feature_propagation_mlp_channels,
         (out_channels, num_blocks, voxel_resolution))
    """
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # All configuration is forwarded verbatim to PVCNN2Base.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Couples the PVCNN2 denoiser with the GaussianDiffusion process for
    point-cloud completion.

    Clouds are (B, C, N); the first ``args.svpoints`` points along the last
    axis are the fixed partial observation, the rest are generated.
    """

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        # Diffusion operates only on the free points after the conditioned prefix.
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)

        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # Prior term of the variational bound (bits/dim), per batch element.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        """Full variational bound decomposition; returns scalar means."""
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt': mse_bt
        }

    def _denoise(self, data, t):
        # data: (B, C, N) float; t: (B,) int64 timestep indices.
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        """Sample random timesteps and return per-sample training losses."""
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # NOTE(review): resamples noise in place for t != 0, mutating the
            # caller's tensor — confirm this side effect is intended.
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        # Run the reverse chain conditioned on the partial cloud.
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def gen_samples_traj2(self, partial_x, shape, device, noise_fn=torch.randn, num_save=20,
                          clip_denoised=False,
                          keep_running=False):
        # Same as gen_samples but returns intermediate clouds for visualization.
        return self.diffusion.p_sample_loop_trajectory2(partial_x, self._denoise, shape=shape, device=device, num_save=num_save, noise_fn=noise_fn,
                                                        clip_denoised=clip_denoised,
                                                        keep_running=keep_running)

    # NOTE(review): these shadow nn.Module.train(mode=True)/eval() with a
    # narrower signature and only toggle the inner network.
    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Replace the inner network with a wrapped (e.g. DataParallel) version.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build the diffusion noise (beta) schedule.

    Args:
        schedule_type: 'linear', or 'warm<frac>' (e.g. 'warm0.1', 'warm0.2',
            'warm0.5'): constant at ``b_end`` with a linear warm-up from
            ``b_start`` over the first ``frac`` of the timesteps.
        b_start: beta at the start of the ramp.
        b_end: beta at the end of the schedule.
        time_num: number of diffusion timesteps.

    Returns:
        np.ndarray of shape (time_num,), dtype float64.

    Raises:
        NotImplementedError: for an unrecognized schedule_type.
    """
    if schedule_type == 'linear':
        # float endpoints already give float64 output.
        return np.linspace(b_start, b_end, time_num)
    if schedule_type.startswith('warm'):
        # Generalizes the previously duplicated warm0.1/warm0.2/warm0.5
        # branches: any warm-up fraction is accepted, same math as before.
        try:
            frac = float(schedule_type[len('warm'):])
        except ValueError:
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
        return betas
    raise NotImplementedError(schedule_type)
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints, category):
    """Build the single-view-reconstruction validation dataset.

    The train split is instantiated only to compute normalization statistics
    (all_points_mean / all_points_std), which are reused for the val split so
    both share the same coordinate frame.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
                                        categories=category, split='train',
                                        tr_sample_size=npoints,
                                        te_sample_size=npoints,
                                        scale=1.,
                                        normalize_per_shape=False,
                                        normalize_std_per_axis=False,
                                        random_subsample=True)
    # Fixed rendering/camera parameters for the single-view depth input.
    te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
                                            cache=os.path.join(mesh_root, '../cache'), split='val',
                                            categories=category,
                                            radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
                                            npoints=npoints, sv_samples=200,
                                            all_points_mean=tr_dataset.all_points_mean,
                                            all_points_std=tr_dataset.all_points_std,
                                            )
    return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
    """Build the multi-view-reconstruction validation dataset.

    As in get_svr_dataset, the train split exists only to provide the
    normalization statistics applied to the val split.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
                                        categories=category, split='train',
                                        tr_sample_size=npoints,
                                        te_sample_size=npoints,
                                        scale=1.,
                                        normalize_per_shape=False,
                                        normalize_std_per_axis=False,
                                        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
                                           cache=os.path.join(pc_dataroot, '../cache'), split='val',
                                           categories=category,
                                           npoints=npoints, sv_samples=200,
                                           all_points_mean=tr_dataset.all_points_mean,
                                           all_points_std=tr_dataset.all_points_std,
                                           )
    return te_dataset
def generate_video(netE, opt, save_dir):
    """Render completion trajectories for the SVR val set as image/XML frames.

    For each batch, writes the partial inputs once, then draws 6 completion
    modes per shape; every mode dumps its denoising trajectory as Mitsuba XML
    frames plus a final rotating-view sequence.
    NOTE(review): export_to_pc_batch / write_to_xml_batch / write_to_xml /
    pcwrite / plt come from the star-imported utils modules — not visible here.
    """
    test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        # gt_all = data['test_points']
        x_all = data['sv_points']
        img_all = data['image']

        # Save the partial (conditioning) points once per batch.
        export_to_pc_batch(
            os.path.join(save_dir, 'batch_%03d_ply' % i), x_all[:, :opt.svpoints, :].numpy())

        write_to_xml_batch(os.path.join(save_dir, 'batch_%03d' % i),
                           x_all[:, :opt.svpoints, :].numpy(), cat='chair')

        # 6 stochastic completion modes per input.
        for v in range(6):
            x = x_all.transpose(1, 2).contiguous()
            img = img_all

            # Per-shape de-normalization statistics.
            m, s = data['mean'].float(), data['std'].float()

            # Trajectory of num_save intermediate clouds per shape.
            gen_all = netE.gen_samples_traj2(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', num_save=55,
                                             clip_denoised=False)
            gen_all = torch.stack(gen_all, dim=1).detach().cpu()

            gen_all = gen_all.transpose(2, 3).contiguous()

            gen_all = gen_all * s[:, None] + m[:, None]
            x = x.transpose(1, 2).contiguous()

            for p, d in enumerate(zip(list(gen_all), list(img))):
                # Depth image is flipped to match the rendering orientation.
                im = np.fliplr(np.flipud(d[-1]))
                gen = d[0]
                plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png' % (i, p)), im, cmap='gray')

                # gen
                write_to_xml_batch(os.path.join(save_dir, 'batch_%03d' % i, 'sample_%03d/mode_%03d/xml/gen_process/' % (p, v)),
                                   gen.numpy(), cat='chair')

            # Final cloud of each trajectory as a .ply.
            for p, gen in enumerate(gen_all[:, -1]):
                Path(os.path.join(save_dir, 'batch_%03d_ply' % i, 'sample_%03d' % p,
                                  'mode_%03d' % v)).mkdir(parents=True, exist_ok=True)
                pcwrite(
                    os.path.join(save_dir, 'batch_%03d_ply' % i, 'sample_%03d/mode_%03d/partial.ply' % (p, v)), gen.numpy())

            # 50-frame rotating view of the final cloud.
            for k, pcl in enumerate(gen_all[:, -1].cpu().numpy()):
                dir_ = os.path.join(save_dir, 'batch_%03d' % i, 'sample_%03d/mode_%03d/xml/rotate_final/' % (k, v))
                Path(dir_).mkdir(parents=True, exist_ok=True)
                for azim in np.linspace(45, 405 - (360 / 50), 50):
                    write_to_xml(
                        os.path.join(dir_, 'azim_%03d.xml' % azim),
                        pcl, cat='chair', elev=19.471, azim=azim)
def generate_video_redwood(netE, opt, save_dir):
    """Complete a single Redwood partial scan and render 20 completion modes.

    NOTE(review): input paths are hardcoded absolute paths on the author's
    machine; 'gt.npy' is written to the current working directory.
    """
    import open3d as o3d
    pth = "/viscam/u/alexzhou907/research/diffusion/redwood/09620_pc_partial.ply"
    pth_gt = "/viscam/u/alexzhou907/research/diffusion/redwood/09620_pc.ply"

    points = np.asarray(o3d.io.read_point_cloud(pth).points)
    gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points)
    np.save('gt.npy', gt_points)

    # Train split instantiated only to borrow normalization statistics.
    test_dataset = ShapeNet15kPointClouds(root_dir=opt.dataroot_pc,
                                          categories=opt.classes, split='train',
                                          tr_sample_size=opt.npoints,
                                          te_sample_size=opt.npoints,
                                          scale=1.,
                                          normalize_per_shape=False,
                                          normalize_std_per_axis=False,
                                          random_subsample=True)

    m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float()

    # Subsample svpoints conditioning points and normalize into model space.
    x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.svpoints, replace=False)]).float()
    x = (x - m) / s

    x = x[None].transpose(1, 2).cuda()

    # Shape of the free points to be generated: (1, C, npoints - svpoints).
    shape = list(x.shape)
    shape[-1] = opt.npoints - shape[-1]
    res = []
    for v in tqdm(range(20)):
        gen_all = netE.gen_samples_traj2(x.cuda(), torch.Size(shape), 'cuda', num_save=55,
                                         clip_denoised=False)
        gen_all = torch.stack(gen_all, dim=1).detach().cpu()

        gen_all = gen_all.transpose(2, 3).contiguous()

        # De-normalize back to world coordinates.
        gen_all = gen_all * s[:, None] + m[:, None]

        res.append(gen_all[:, -1].cpu())

        for p, gen in enumerate(gen_all):
            # gen
            write_to_xml_batch(
                os.path.join(save_dir, 'mode_%03d/xml/gen_process/' % (v)),
                gen.numpy(), cat='chair')

        # 50-frame rotating view of the final cloud.
        for k, pcl in enumerate(gen_all[:, -1].cpu().numpy()):
            dir_ = os.path.join(save_dir, 'mode_%03d/xml/rotate_final/' % (v))
            Path(dir_).mkdir(parents=True, exist_ok=True)
            for azim in np.linspace(45, 405 - (360 / 50), 50):
                write_to_xml(
                    os.path.join(dir_, 'azim_%03d.xml' % azim),
                    pcl, cat='chair', elev=19.471, azim=azim)

        pcwrite(os.path.join(save_dir, 'mode_%03d.ply' % v), gen_all[:, -1].cpu().numpy()[0])

    pcwrite(os.path.join(save_dir, 'gt.ply'),
            gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)])
def main(opt):
    """Set up output dirs and the model, load the newest checkpoint, render.

    Checkpoints in opt.ckpt_dir are sorted by their trailing epoch number;
    only the newest one is processed because of the exit() inside the loop.
    """
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    outf_syn, = setup_output_subdirs(output_dir, 'syn')

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        return nn.parallel.DataParallel(m)

    # NOTE(review): .cuda() is called unconditionally here even though it was
    # already guarded by opt.cuda above — this path requires a GPU.
    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)

    netE.eval()

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    with torch.no_grad():
        for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]))):
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)

            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])

            generate_video_redwood(netE, opt, outf_syn)
            # Only the first (newest) checkpoint is rendered.
            exit()
def parse_args(argv=None):
    """Parse command-line options for completion evaluation.

    Args:
        argv: optional list of argument tokens; defaults to sys.argv[1:]
            (backward compatible — existing callers pass nothing).

    Returns:
        argparse.Namespace; additionally sets ``opt.cuda`` from CUDA
        availability.
    """
    parser = argparse.ArgumentParser()
    # FIX: help strings below were copy-pasted ('input batch size') on path
    # arguments; corrected to describe the actual option.
    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
                        help='path to ShapeNet point-cloud data')
    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                        help='path to preprocessed single-view renderings')
    parser.add_argument('--classes', default=['chair'])

    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    # NOTE(review): these are plain-default options, not store_true actions —
    # passing any string on the CLI (even "False") yields a truthy value.
    parser.add_argument('--eval_recon_mvr', default=False)
    parser.add_argument('--generate_multimodal', default=False)
    parser.add_argument('--eval_saved', default=False)
    parser.add_argument('--eval_redwood', default=True)

    # FIX: numeric options previously lacked type=..., so values supplied on
    # the command line arrived as strings. Defaults are unchanged.
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)
    parser.add_argument('--svpoints', type=int, default=200)
    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    # params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)

    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)")

    '''eval'''
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')

    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')

    opt = parser.parse_args(argv)

    # Expose CUDA availability as a flag consumed by main().
    if torch.cuda.is_available():
        opt.cuda = True
    else:
        opt.cuda = False

    return opt
if __name__ == '__main__':
    # Entry point: parse CLI options and run the evaluation pipeline.
    main(parse_args())

View file

@ -1,681 +0,0 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, e^logvar1) || N(mean2, e^logvar2) ) in nats."""
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (logvar2 - logvar1 - 1.0 + var_ratio + scaled_sq_diff)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
    """DDPM-style Gaussian diffusion specialized for point-cloud completion.

    Tensors are laid out as (B, C, N); the first ``sv_points`` entries along
    the last axis are the conditioned partial view and are never noised —
    the diffusion process runs only on the remaining points.
    """
    def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
        """Precompute all schedule-derived constants from the beta array.

        Args:
            betas: np.ndarray of per-timestep noise variances in (0, 1].
            loss_type: 'mse' (epsilon regression) or 'kl' (variational bound).
            model_mean_type: only 'eps' is implemented downstream.
            model_var_type: 'fixedsmall' or 'fixedlarge'.
            sv_points: number of leading conditioned (partial-view) points.
        """
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.sv_points = sv_points

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """Mean/variance/log-variance of the forward marginal q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Model's reverse-step distribution p(x_{t-1} | x_t) over the free
        points only; ``data`` includes the conditioned prefix."""
        model_output = denoise_fn(data, t)[:, :, self.sv_points:]

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points:], t=t, eps=model_output)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points:], t=t)
        else:
            # NOTE(review): the branch checks model_mean_type but the error
            # message reports loss_type — likely a copy-paste slip.
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape
        assert model_variance.shape == model_log_variance.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Invert the forward marginal: recover x_0 from x_t and predicted eps."""
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """
        Sample from the model
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                              return_pred_xstart=True)
        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)

        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        # Re-attach the (unchanged) conditioned prefix to the sampled free points.
        sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, partial_x, denoise_fn, shape, device,
                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
        """
        assert isinstance(shape, (tuple, list))
        # Start from pure noise for the free points, conditioned on partial_x.
        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False)

        assert img_t[:, :, self.sv_points:].shape == shape
        return img_t

    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                                 noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time
        Args:
          repeat_noise_steps (int): Number of denoising timesteps in which the same noise
            is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
        """
        # NOTE(review): unlike p_sample_loop, no partial_x is concatenated
        # here, so the first sv_points columns are plain noise and are never
        # replaced with a real partial view — looks unfinished/unused.
        assert isinstance(shape, (tuple, list))

        total_steps = self.num_timesteps if not keep_running else len(self.betas)

        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0, total_steps)):

            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised,
                                  return_pred_xstart=False)
            if t % freq == 0 or t == total_steps - 1:
                imgs.append(img_t)

        assert imgs[-1].shape == shape
        return imgs

    '''losses'''

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        """Per-sample variational-bound term (bits/dim) for one timestep,
        computed over the free points only."""
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:, :, self.sv_points:], x_t=data_t[:, :, self.sv_points:], t=t)
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        # Mean over non-batch dims; divide by ln(2) for nats -> bits.
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)

        return (kl, pred_xstart) if return_pred_xstart else kl

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """
        Training loss calculation
        """
        B, D, N = data_start.shape
        assert t.shape == torch.Size([B])

        if noise is None:
            noise = torch.randn(data_start[:, :, self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)

        # Only the free points are noised; the partial view stays clean.
        data_t = self.q_sample(x_start=data_start[:, :, self.sv_points:], t=t, noise=noise)

        if self.loss_type == 'mse':
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(torch.cat([data_start[:, :, :self.sv_points], data_t], dim=-1), t)[:, :, self.sv_points:]
            losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == 'kl':
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
                return_pred_xstart=False)
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == torch.Size([B])
        return losses

    '''debug'''

    def _prior_bpd(self, x_start):
        """Prior term of the VLB in bits/dim: KL( q(x_T | x_0) || N(0, I) )."""
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
                                 mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        """Evaluate the full variational bound timestep by timestep; returns
        (mean total bpd, per-timestep VLB terms, mean prior bpd, per-timestep MSE)."""
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps

            vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):

                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                    clip_denoised=clip_denoised, return_pred_xstart=True)
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                # NOTE(review): this assert compares vals_bt_ to itself twice;
                # mse_bt_ is never shape-checked — likely a typo.
                assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])

            prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points:])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
                total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """Point-Voxel CNN backbone configuration used as the denoising network.

    The block tuples are consumed by PVCNN2Base (defined elsewhere in the
    project). NOTE(review): the field meanings below are inferred from the
    values, not visible here — confirm against model/pvcnn_completion.py:
      sa_blocks: per downsampling stage,
        ((out_channels, num_blocks, voxel_resolution) or None,
         (num_centers, radius, num_neighbors, mlp_channels))
      fp_blocks: per upsampling stage,
        (feature_propagation_mlp_channels,
         (out_channels, num_blocks, voxel_resolution))
    """
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # All configuration is forwarded verbatim to PVCNN2Base.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Couples the PVCNN2 denoiser with the GaussianDiffusion process for
    point-cloud completion.

    Clouds are (B, C, N); the first ``args.svpoints`` points along the last
    axis are the fixed partial observation, the rest are generated.
    """

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        # Diffusion operates only on the free points after the conditioned prefix.
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)

        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # Prior term of the variational bound (bits/dim), per batch element.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        """Full variational bound decomposition; returns scalar means."""
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt': mse_bt
        }

    def _denoise(self, data, t):
        # data: (B, C, N) float; t: (B,) int64 timestep indices.
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        """Sample random timesteps and return per-sample training losses."""
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # NOTE(review): resamples noise in place for t != 0, mutating the
            # caller's tensor — confirm this side effect is intended.
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        # Run the reverse chain conditioned on the partial cloud.
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    # NOTE(review): these shadow nn.Module.train(mode=True)/eval() with a
    # narrower signature and only toggle the inner network.
    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Replace the inner network with a wrapped (e.g. DataParallel) version.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build the diffusion noise (beta) schedule.

    Args:
        schedule_type: 'linear', or 'warm<frac>' (e.g. 'warm0.1', 'warm0.2',
            'warm0.5'): constant at ``b_end`` with a linear warm-up from
            ``b_start`` over the first ``frac`` of the timesteps.
        b_start: beta at the start of the ramp.
        b_end: beta at the end of the schedule.
        time_num: number of diffusion timesteps.

    Returns:
        np.ndarray of shape (time_num,), dtype float64.

    Raises:
        NotImplementedError: for an unrecognized schedule_type.
    """
    if schedule_type == 'linear':
        # float endpoints already give float64 output.
        return np.linspace(b_start, b_end, time_num)
    if schedule_type.startswith('warm'):
        # Generalizes the previously duplicated warm0.1/warm0.2/warm0.5
        # branches: any warm-up fraction is accepted, same math as before.
        try:
            frac = float(schedule_type[len('warm'):])
        except ValueError:
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
        return betas
    raise NotImplementedError(schedule_type)
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints, category):
    """Build the single-view-reconstruction validation dataset.

    The train split is instantiated only to compute normalization statistics
    (all_points_mean / all_points_std), which are reused for the val split so
    both share the same coordinate frame.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
                                        categories=category, split='train',
                                        tr_sample_size=npoints,
                                        te_sample_size=npoints,
                                        scale=1.,
                                        normalize_per_shape=False,
                                        normalize_std_per_axis=False,
                                        random_subsample=True)
    # Fixed rendering/camera parameters for the single-view depth input.
    te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
                                            cache=os.path.join(mesh_root, '../cache'), split='val',
                                            categories=category,
                                            radius=3, elev=0, azim=0, img_size=512, focal_length=1000,
                                            npoints=npoints, sv_samples=200,
                                            all_points_mean=tr_dataset.all_points_mean,
                                            all_points_std=tr_dataset.all_points_std,
                                            )
    return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
    """Build the multi-view-reconstruction validation dataset.

    As in get_svr_dataset, the train split exists only to provide the
    normalization statistics applied to the val split.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
                                        categories=category, split='train',
                                        tr_sample_size=npoints,
                                        te_sample_size=npoints,
                                        scale=1.,
                                        normalize_per_shape=False,
                                        normalize_std_per_axis=False,
                                        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
                                           cache=os.path.join(pc_dataroot, '../cache'), split='val',
                                           categories=category,
                                           npoints=npoints, sv_samples=200,
                                           all_points_mean=tr_dataset.all_points_mean,
                                           all_points_std=tr_dataset.all_points_std,
                                           )
    return te_dataset
def evaluate_recon_mvr(opt, netE, save_dir, logger):
    """Reconstruct every multi-view validation sample, save ground-truth /
    masked-input / reconstruction tensors to `save_dir`, and report EMD/CD.

    Shapes: sv_points come as (B, V, N, C); views are folded into the batch
    for sampling and unfolded again before metrics.
    """
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    ref = []      # denormalized ground-truth clouds
    samples = []  # denormalized reconstructions
    images = []   # depth renderings per view
    masked = []   # observed (partial) points per view
    k = 0  # NOTE(review): unused
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        img_all = data['image']

        # Fold views into batch: (B, V, N, C) -> (B*V, C, N) for the sampler.
        B,V,N,C = x_all.shape
        gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)

        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        img = img_all.reshape(B * V, *img_all.shape[2:])

        m, s = data['mean'].float(), data['std'].float()

        # First opt.svpoints points are the observed partial shape; the
        # remaining points are sampled by the diffusion model.
        recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                 clip_denoised=False).detach().cpu()

        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()

        # Undo dataset normalization before saving / computing metrics.
        x_adj = x.reshape(B,V,N,C)* s + m
        recon_adj = recon.reshape(B,V,N,C)* s + m
        img = img.reshape(B,V,*img.shape[1:])

        ref.append( gt_all * s + m)
        masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
        samples.append(recon_adj)
        images.append(img)

    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)

    B, V, N, C = ref_pcs.shape

    torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B*V, N, C),
                     ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
    # Per-sample metrics come back flattened over B*V; restore (B, V).
    results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}

    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})

    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
def generate_multimodal(opt, netE, save_dir, logger):
    """Sample 5 completion modes per single-view input and export a depth
    render plus xml/ply point clouds for every (batch, sample) pair.

    Outputs go under `save_dir` as depth_%03d_%03d.png, x_%03d_%03d/mode_%03d
    and x_ply_%03d_%03d/mode_%03d.
    """
    test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        img_all = data['image']

        # BUG FIX: the original indexed these paths with an undefined name `k`
        # (NameError on first iteration); the batch index `i` is what was intended.
        Path(os.path.join(save_dir, 'x_%03d' % i)).mkdir(exist_ok=True, parents=True)
        Path(os.path.join(save_dir, 'x_ply_%03d' % i)).mkdir(exist_ok=True, parents=True)

        for v in range(5):  # one independent diffusion sample per mode
            x = x_all.transpose(1, 2).contiguous()
            img = img_all

            m, s = data['mean'].float(), data['std'].float()

            # Condition on the first opt.svpoints observed points.
            recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                     clip_denoised=False).detach().cpu()

            recon = recon.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
                # Depth map is stored flipped; undo both flips before saving.
                im = np.fliplr(np.flipud(d[-1]))
                plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')

                # Denormalize (gt, recon, partial) with the batch statistics.
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())

                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
def main(opt):
    """Build the completion model, then evaluate every checkpoint found in
    opt.ckpt_dir (newest first) with the evaluations enabled on `opt`."""
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    outf_syn, = setup_output_subdirs(output_dir, 'syn')

    # Diffusion schedule and model.
    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        # Wrap the inner backbone in DataParallel for multi-GPU sampling.
        return nn.parallel.DataParallel(m)

    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)

    netE.eval()

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    with torch.no_grad():
        # NOTE(review): str.strip('.pth') strips any of the characters
        # '.', 'p', 't', 'h' from both ends — not the '.pth' suffix. It works
        # here only because checkpoint basenames end in digits.
        for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)

            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])

            if opt.eval_recon_mvr:
                # Evaluate generation
                evaluate_recon_mvr(opt, netE, outf_syn, logger)

            if opt.generate_multimodal:
                generate_multimodal(opt, netE, outf_syn, logger)
def parse_args():
    """Build the evaluation CLI and return the parsed options.

    Also sets `opt.cuda` from torch.cuda.is_available().  `type=` converters
    were added to the numeric options so command-line overrides parse to
    numbers — previously they arrived as strings and broke downstream math.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                        help='input batch size')
    parser.add_argument('--classes', default=['car'])

    parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    # NOTE(review): these are plain Python bools used as defaults; passing
    # them on the command line would yield truthy strings — toggle in code.
    parser.add_argument('--eval_recon_mvr', default=True)
    parser.add_argument('--generate_multimodal', default=False)

    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)
    parser.add_argument('--svpoints', type=int, default=200)

    # model
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    # params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)

    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/3_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-03-08-40', help="path to netE (to continue training)")

    # eval
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')

    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')

    opt = parser.parse_args()

    # Record CUDA availability on the namespace for the rest of the script.
    opt.cuda = torch.cuda.is_available()

    return opt
if __name__ == '__main__':
    opt = parse_args()
    # Seed RNGs before any model/data work; set_seed presumably reads
    # opt.manualSeed (declared in parse_args) — confirm in utils.
    set_seed(opt)

    main(opt)

View file

@ -1,753 +0,0 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ).

    All inputs are tensors of matching (broadcastable) shape; the result has
    the broadcast shape — no reduction is performed.
    """
    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2)**2 * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-likelihood of `x` under a Gaussian discretized into unit bins.

    Assumes data is integers in [0, 1].  The bin probability is
    CDF(x + 0.5) - CDF(x - 0.5), with open-ended bins at both extremes
    (x < 0.001 and x > 0.999).  All probabilities are floored at 1e-12
    before taking the log.
    """
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered = x - means
    inv_stdv = torch.exp(-log_scales)
    cdf_upper = std_normal.cdf(inv_stdv * (centered + 0.5))
    cdf_lower = std_normal.cdf(inv_stdv * (centered - .5))

    floor = 1e-12
    log_cdf_upper = torch.log(cdf_upper.clamp(min=floor))
    log_one_minus_cdf_lower = torch.log((1. - cdf_lower).clamp(min=floor))
    log_cdf_delta = torch.log((cdf_upper - cdf_lower).clamp(min=floor))

    log_probs = torch.where(
        x < 0.001, log_cdf_upper,
        torch.where(x > 0.999, log_one_minus_cdf_lower, log_cdf_delta))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """DDPM-style Gaussian diffusion specialized for point-cloud completion.

    Data is channel-first, shape (B, C, N).  The first `sv_points` entries
    along the last axis are the observed partial shape: they are never
    noised, and only the remaining N - sv_points points are diffused and
    denoised.  Schedule tensors are precomputed from `betas` (float64).
    """
    def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
        # loss_type: 'mse' (noise regression) or 'kl' (variational bound).
        self.loss_type = loss_type
        # model_mean_type: only 'eps' (predict the added noise) is implemented.
        self.model_mean_type = model_mean_type
        # model_var_type: 'fixedsmall' (posterior variance) or 'fixedlarge' (betas).
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.sv_points = sv_points

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        # \bar{alpha}_{t-1}, with \bar{alpha}_{-1} defined as 1.
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        # Moments of q(x_t | x_0): mean = sqrt(abar_t) * x_0, var = 1 - abar_t.
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Moments of p(x_{t-1} | x_t) for the *unknown* part of `data`.

        `data` carries partial + noisy points concatenated along the point
        axis; the network output is cropped so only points after sv_points
        are modeled.
        """
        model_output = denoise_fn(data, t)[:,:,self.sv_points:]

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
        else:
            # NOTE(review): this reports loss_type but the branch tests
            # model_mean_type — likely should raise with self.model_mean_type.
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape
        assert model_variance.shape == model_log_variance.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        # Invert q_sample: x_0 = sqrt(1/abar_t) x_t - sqrt(1/abar_t - 1) eps.
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """
        Sample from the model
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                              return_pred_xstart=True)
        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)

        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))

        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        # Re-attach the (never-noised) partial input in front of the sample.
        sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, partial_x, denoise_fn, shape, device,
                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

        `shape` is the shape of the *missing* part only; `partial_x` is
        concatenated in front and kept fixed throughout the reverse chain.
        Returns the full cloud (partial + generated).
        """
        assert isinstance(shape, (tuple, list))
        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False)

        assert img_t[:,:,self.sv_points:].shape == shape
        return img_t

    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                                 noise_fn=torch.randn,clip_denoised=True, keep_running=False):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time
        Args:
          repeat_noise_steps (int): Number of denoising timesteps in which the same noise
            is used across the batch. If >= 0, the initial noise is the same for all batch elemements.

        NOTE(review): unlike p_sample_loop, this starts from pure noise with
        no partial_x prepended, while p_sample still slices off sv_points —
        verify before use.
        """
        assert isinstance(shape, (tuple, list))

        total_steps =  self.num_timesteps if not keep_running else len(self.betas)

        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0,total_steps)):

            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised,
                                  return_pred_xstart=False)
            if t % freq == 0 or t == total_steps-1:
                imgs.append(img_t)

        assert imgs[-1].shape == shape
        return imgs

    '''losses'''

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        # One variational-bound KL term (in bits) for timestep t, computed
        # only over the unknown part of the cloud.
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)

        return (kl, pred_xstart) if return_pred_xstart else kl

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """
        Training loss calculation
        """
        B, D, N = data_start.shape
        assert t.shape == torch.Size([B])

        if noise is None:
            noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)

        # Only the unknown part is noised; the partial input is re-attached
        # when calling the denoiser.
        data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)

        if self.loss_type == 'mse':
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
            losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == 'kl':
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
                return_pred_xstart=False)
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == torch.Size([B])
        return losses

    '''debug'''

    def _prior_bpd(self, x_start):
        # KL( q(x_T | x_0) || N(0, I) ) in bits per dimension.
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
                                 mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        """Full bits-per-dim evaluation: per-timestep VLB terms, per-timestep
        x_0 MSE, prior term and their total (all batch means)."""
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps

            vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):

                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                    clip_denoised=clip_denoised, return_pred_xstart=True)
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                # NOTE(review): `vals_bt_` is compared twice below; the second
                # occurrence was probably meant to be `mse_bt_`.
                assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])

            prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
                   total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """Concrete PVCNN2 completion backbone: only pins down the set-abstraction
    and feature-propagation block configuration consumed by PVCNN2Base."""
    # Encoder (set-abstraction) configuration; tuple layout is interpreted by
    # PVCNN2Base (see model.pvcnn_completion).
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    # Decoder (feature-propagation) configuration, mirrored against sa_blocks.
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        """Forward all hyper-parameters to PVCNN2Base unchanged."""
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Completion model: PVCNN2 denoiser + GaussianDiffusion process."""
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)

        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # KL(q(x_T | x_0) || N(0, I)), per example, in bits/dim.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        """Full variational-bound diagnostics over all timesteps."""
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt':mse_bt
        }

    def _denoise(self, data, t):
        # data: (B, C, N) float points; t: (B,) int64 timesteps.
        B, D,N= data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        """Draw one random timestep per example and return per-example losses."""
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # Resample the provided noises in place for every example with t != 0.
            noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        """Complete `partial_x` by sampling missing points of `shape`;
        returns the full (partial + generated) cloud."""
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def train(self):
        # NOTE(review): shadows nn.Module.train(mode=True) — only toggles the backbone.
        self.model.train()

    def eval(self):
        # NOTE(review): shadows nn.Module.eval() — only toggles the backbone.
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Replace the backbone with a wrapped version (e.g. DataParallel).
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Return the noise schedule `betas` as a numpy array of length time_num.

    'linear' interpolates b_start -> b_end; 'warm<f>' ramps linearly over the
    first fraction f of the steps and stays at b_end afterwards.
    Raises NotImplementedError for unknown schedule names.
    """
    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num)

    warm_fractions = {'warm0.1': 0.1, 'warm0.2': 0.2, 'warm0.5': 0.5}
    if schedule_type not in warm_fractions:
        raise NotImplementedError(schedule_type)

    betas = b_end * np.ones(time_num, dtype=np.float64)
    warmup_time = int(time_num * warm_fractions[schedule_type])
    betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    return betas
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints, category):
    """Build the single-view (elev=-89, azim=180) validation set.

    The train split is loaded first purely to obtain the global normalization
    statistics, which are then applied to the validation data.
    """
    train_split = ShapeNet15kPointClouds(
        root_dir=pc_dataroot,
        categories=category,
        split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
    )
    return ShapeNet_Singleview_Points(
        root_pc=pc_dataroot,
        root_mesh=mesh_root,
        cache=os.path.join(mesh_root, '../cache'),
        split='val',
        categories=category,
        radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
        npoints=npoints, sv_samples=200,
        all_points_mean=train_split.all_points_mean,
        all_points_std=train_split.all_points_std,
    )
def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
    """Build the multi-view validation set, normalized with the train-split
    statistics (train split is loaded only for its mean/std)."""
    train_split = ShapeNet15kPointClouds(
        root_dir=pc_dataroot,
        categories=category,
        split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
    )
    return ShapeNet_Multiview_Points(
        root_pc=pc_dataroot,
        root_views=views_root,
        cache=os.path.join(pc_dataroot, '../cache'),
        split='val',
        categories=category,
        npoints=npoints, sv_samples=200,
        all_points_mean=train_split.all_points_mean,
        all_points_std=train_split.all_points_std,
    )
def evaluate_recon_mvr(opt, netE, save_dir, logger):
    """Reconstruct every multi-view validation sample, save ground-truth /
    masked-input tensors and metrics to `save_dir`, and report EMD/CD.

    Image handling is commented out in this variant.
    """
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    ref = []      # denormalized ground-truth clouds
    samples = []  # denormalized reconstructions
    images = []   # unused in this variant (image path commented out)
    masked = []   # observed (partial) points per view
    k = 0  # NOTE(review): unused
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        # img_all = data['image']

        # Fold views into batch: (B, V, N, C) -> (B*V, C, N) for the sampler.
        B,V,N,C = x_all.shape
        gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)

        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        # img = img_all.reshape(B * V, *img_all.shape[2:])

        m, s = data['mean'].float(), data['std'].float()

        # First opt.svpoints points are the observed partial shape; the
        # remaining points are sampled by the diffusion model.
        recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                 clip_denoised=False).detach().cpu()

        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()

        # Undo dataset normalization before saving / computing metrics.
        x_adj = x.reshape(B,V,N,C)* s + m
        recon_adj = recon.reshape(B,V,N,C)* s + m
        # img = img.reshape(B,V,*img.shape[1:])

        ref.append( gt_all * s + m)
        masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
        samples.append(recon_adj)
        # images.append(img)

    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    # images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)

    B, V, N, C = ref_pcs.shape

    torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B*V, N, C),
                     ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
    # Per-sample metrics come back flattened over B*V; restore (B, V).
    results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}

    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})

    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))

    # Free the large tensors before the caller moves to the next checkpoint.
    del ref_pcs, masked, results
def evaluate_saved(opt, netE, save_dir, logger):
    """Re-score previously saved reconstructions against saved ground truth.

    Loads the tensors written by evaluate_recon_mvr from a fixed results
    directory, computes the full metric suite per view, and prints the
    view-averaged results.  `netE`, `save_dir` and `logger` are unused but
    kept for a uniform evaluation-function signature.
    """
    ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
    gt_pth = ours_base + '/recon_gt.pth'
    ours_pth = ours_base + '/ours_results.pth'

    # Stored as (B, V, N, C); iterate per view -> (V, B, N, C).
    gt = torch.load(gt_pth).permute(1,0,2,3)
    ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)

    accumulated = {}
    for view_idx, (gt_view, ours_view) in enumerate(zip(gt, ours)):
        view_metrics = compute_all_metrics(gt_view, ours_view, opt.batch_size)

        for name, value in view_metrics.items():
            if view_idx == 0:
                accumulated[name] = value
            else:
                accumulated[name] = accumulated[name] + value
        pprint(view_metrics)

    # Average the summed metrics over the number of views.
    n_views = gt.shape[0]
    for name, value in accumulated.items():
        accumulated[name] = value / n_views

    pprint({key: val.mean().item() for key, val in accumulated.items()})
def generate_multimodal(opt, netE, save_dir, logger):
    """Draw 6 completion modes per single-view input and export a depth
    render plus xml/ply point clouds for each (batch, sample) pair."""
    test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        img_all = data['image']

        for v in range(6):  # one independent diffusion sample per mode
            x = x_all.transpose(1, 2).contiguous()
            img = img_all

            m, s = data['mean'].float(), data['std'].float()

            # Condition on the first opt.svpoints observed points.
            recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                     clip_denoised=False).detach().cpu()

            recon = recon.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
                # Depth map is stored flipped; undo both flips before saving.
                im = np.fliplr(np.flipud(d[-1]))
                plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')

                # Denormalize (gt, recon, partial) with the batch statistics.
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair')

                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
def redwood_demo(opt, netE, save_dir, logger):
    """Complete a real Redwood scan with 20 sampled modes and export results.

    Reads a fixed partial/complete scan pair, normalizes the partial cloud
    with the validation dataset's statistics, samples 20 completions, and
    writes xml/ply/pth outputs into `save_dir`.  Calls exit() when done.
    """
    import open3d as o3d
    pth = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc_partial.ply"
    pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc.ply"

    points = np.asarray(o3d.io.read_point_cloud(pth).points)
    gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points)

    np.save('gt.npy', gt_points)

    # Dataset is built only for its normalization statistics.
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)

    m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float()

    # Subsample the partial scan to opt.npoints and normalize.
    x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float()
    x = (x - m) / s
    # BUG FIX: the original called x.transpose(1, 2) on a 2-D (npoints, 3)
    # tensor, which raises a dimension-out-of-range error.  Add the batch
    # axis first, then move channels in front -> (1, 3, npoints).
    x = x.reshape(1, opt.npoints, -1).transpose(1, 2).cuda()

    res = []
    for k in range(20):  # 20 independent completion modes
        recon = netE.gen_samples(x[:, :, :opt.svpoints], x[:, :, opt.svpoints:].shape, 'cuda',
                                 clip_denoised=False).detach().cpu()
        recon = recon.transpose(1, 2).contiguous()
        recon = recon * s + m
        res.append(recon)
    res = torch.cat(res, dim=0)

    write_to_xml_batch(os.path.join(save_dir, 'xml'),
                       (res).numpy(), cat='chair')
    export_to_pc_batch(os.path.join(save_dir, 'ply'),
                       (res).numpy())
    torch.save(res, os.path.join(save_dir, 'redwood_demo.pth'))

    pcwrite(os.path.join(save_dir, 'ply', 'gt.ply'),
            gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)])
    write_to_xml_batch(os.path.join(save_dir, 'xml_gt'),
                       gt_points[None], cat='chair')

    exit()
def main(opt):
    """Build the completion model, then run the evaluations enabled on `opt`
    for every checkpoint found in opt.ckpt_dir (newest first)."""
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    outf_syn, = setup_output_subdirs(output_dir, 'syn')

    # Diffusion schedule and model.
    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        # Wrap the inner backbone in DataParallel for multi-GPU sampling.
        return nn.parallel.DataParallel(m)

    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)

    netE.eval()

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    with torch.no_grad():
        # NOTE(review): str.strip('.pth') strips any of the characters
        # '.', 'p', 't', 'h' from both ends — not the '.pth' suffix. It works
        # here only because checkpoint basenames end in digits.
        for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)

            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])

            if opt.eval_recon_mvr:
                # Evaluate generation
                evaluate_recon_mvr(opt, netE, outf_syn, logger)

            if opt.generate_multimodal:
                generate_multimodal(opt, netE, outf_syn, logger)

            if opt.eval_saved:
                evaluate_saved(opt, netE, outf_syn, logger)

            if opt.eval_redwood:
                # NOTE: redwood_demo calls exit(), ending the checkpoint loop.
                redwood_demo(opt, netE, outf_syn, logger)
def parse_args():
    """Parse evaluation options; also sets ``opt.cuda`` from CUDA availability.

    Fix: numeric options previously had no ``type=``, so any value passed on
    the command line arrived as ``str`` (the int/float defaults only applied
    when the flag was omitted). Explicit types are added; defaults unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                        help='input batch size')
    parser.add_argument('--classes', default=['chair'])
    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
    # NOTE(review): these switches are plain defaults; any non-empty string
    # passed on the CLI (including "False") is truthy. Kept as-is so existing
    # invocations keep working.
    parser.add_argument('--eval_recon_mvr', default=False)
    parser.add_argument('--generate_multimodal', default=False)
    parser.add_argument('--eval_saved', default=False)
    parser.add_argument('--eval_redwood', default=True)
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)
    parser.add_argument('--svpoints', type=int, default=200)
    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)
    # params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')
    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)
    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)")
    '''eval'''
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
    opt = parser.parse_args()
    # Mirror CUDA availability onto the namespace used throughout the script.
    opt.cuda = torch.cuda.is_available()
    return opt
if __name__ == '__main__':
    # Script entry point: parse CLI options, then run the evaluation driver.
    main(parse_args())

View file

@ -1,599 +0,0 @@
from pprint import pprint
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.partnet import GANdatasetPartNet
import trimesh
import csv
import numpy as np
import random
from plyfile import PlyData, PlyElement
def write_ply(points, filename, text=False):
    """Write an (N, 3) point array to *filename* as a PLY 'vertex' element."""
    xyz_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    rows = [tuple(points[i, :3]) for i in range(points.shape[0])]
    vertex = np.array(rows, dtype=xyz_dtype)
    element = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    with open(filename, mode='wb') as handle:
        PlyData([element], text=text).write(handle)
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, e^logvar1) || N(mean2, e^logvar2) )."""
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (var_ratio + scaled_sq_diff - 1.0 + logvar2 - logvar1)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-likelihood of x under a Gaussian discretized into 1-wide bins.

    Assumes data is integers in [0, 1]; the two outer branches handle the
    open-ended leftmost (x < 0.001) and rightmost (x > 0.999) bins.
    """
    assert x.shape == means.shape == log_scales.shape
    standard_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
    centered = x - means
    inv_stdv = torch.exp(-log_scales)
    cdf_upper = standard_normal.cdf(inv_stdv * (centered + 0.5))
    cdf_lower = standard_normal.cdf(inv_stdv * (centered - 0.5))
    tiny = 1e-12
    # clamp(min=tiny) == torch.max(·, ones*tiny): guards log() against zeros.
    log_cdf_upper = torch.log(cdf_upper.clamp(min=tiny))
    log_one_minus_cdf_lower = torch.log((1. - cdf_lower).clamp(min=tiny))
    log_delta = torch.log((cdf_upper - cdf_lower).clamp(min=tiny))
    log_probs = torch.where(
        x < 0.001, log_cdf_upper,
        torch.where(x > 0.999, log_one_minus_cdf_lower, log_delta))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """Discrete-time Gaussian diffusion for partial-shape completion.

    Tensors are laid out (B, C, N); the first ``sv_points`` entries along the
    last axis are the fixed partial points and the remainder is the free
    region that is noised during q and denoised during p. All schedule
    buffers are precomputed once in ``__init__``.
    """
    def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
        # loss_type: 'mse' (noise regression) or 'kl' (variational bound);
        # model_mean_type: only 'eps' (predict noise) is implemented;
        # model_var_type: 'fixedsmall' or 'fixedlarge'.
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.sv_points = sv_points
        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """Mean/variance/log-variance of the marginal q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Reverse-step mean/variance p(x_{t-1} | x_t) on the free region only.

        ``data`` is the full (partial + noisy free) cloud; the network output
        is sliced past sv_points so the partial points are never denoised.
        NOTE(review): ``clip_denoised`` is accepted but no clipping is applied
        anywhere in this method.
        """
        model_output = denoise_fn(data, t)[:,:,self.sv_points:]
        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
        else:
            raise NotImplementedError(self.model_var_type)
        if self.model_mean_type == 'eps':
            x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
        else:
            # NOTE(review): this branch switches on model_mean_type but raises
            # with self.loss_type — the error message would be misleading.
            raise NotImplementedError(self.loss_type)
        assert model_mean.shape == x_recon.shape
        assert model_variance.shape == model_log_variance.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Invert q_sample: recover x_0 from x_t and the predicted noise eps."""
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """
        Sample from the model
        """
        # One reverse step; the partial region is re-attached untouched below.
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                             return_pred_xstart=True)
        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, partial_x, denoise_fn, shape, device,
                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

        ``partial_x`` is concatenated in front of pure noise of ``shape``;
        only the noisy tail is denoised, and the full cloud (partial + tail)
        is returned.
        """
        assert isinstance(shape, (tuple, list))
        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False)
        assert img_t[:,:,self.sv_points:].shape == shape
        return img_t

    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                                 noise_fn=torch.randn,clip_denoised=True, keep_running=False):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time
        Args:
          repeat_noise_steps (int): Number of denoising timesteps in which the same noise
            is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
        """
        assert isinstance(shape, (tuple, list))
        total_steps = self.num_timesteps if not keep_running else len(self.betas)
        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0,total_steps)):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised,
                                  return_pred_xstart=False)
            # Record a snapshot every ``freq`` steps (and the final step).
            if t % freq == 0 or t == total_steps-1:
                imgs.append(img_t)
        assert imgs[-1].shape == shape
        return imgs

    '''losses'''

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        """Per-timestep variational-bound term in bits-per-dim (free region only)."""
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        # Mean over all non-batch dims; /log(2) converts nats to bits.
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
        return (kl, pred_xstart) if return_pred_xstart else kl

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """
        Training loss calculation
        """
        B, D, N = data_start.shape
        assert t.shape == torch.Size([B])
        if noise is None:
            noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
        # Only the free region past sv_points is noised.
        data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
        if self.loss_type == 'mse':
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
            losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == 'kl':
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
                return_pred_xstart=False)
        else:
            raise NotImplementedError(self.loss_type)
        assert losses.shape == torch.Size([B])
        return losses

    '''debug'''

    def _prior_bpd(self, x_start):
        """Prior term: KL between q(x_T | x_0) and the standard-normal prior, in bits."""
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
                                 mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        """Full bits-per-dim evaluation: per-timestep VLB terms, MSE, prior, total."""
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):
                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                    clip_denoised=clip_denoised, return_pred_xstart=True)
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])
            prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
                total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """PVCNN2 completion backbone: fixes the set-abstraction / feature-propagation
    configuration and forwards everything else to PVCNN2Base.

    Each sa_blocks entry appears to pair a point-voxel conv config with a
    set-abstraction config, and each fp_blocks entry a feature-propagation
    config with a conv config — see PVCNN2Base for the exact field meanings
    (not visible from this file).
    """
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through constructor; the class only pins the block configs above.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Bundles the PVCNN2 denoiser with its GaussianDiffusion process and
    exposes training-loss, likelihood, and sampling entry points."""

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # Prior term of the VLB for x0, in bits-per-dim.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        """Full bits-per-dim breakdown (total / per-term / prior / MSE)."""
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt':mse_bt
        }

    def _denoise(self, data, t):
        # Thin shape/dtype-checked wrapper around the backbone forward pass.
        B, D,N= data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64
        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        """One training iteration's per-sample losses at uniformly random timesteps."""
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
        if noises is not None:
            # Resample the provided noise wherever t != 0 — presumably so only
            # t == 0 samples keep caller-supplied noise; verify against callers.
            noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        """Complete ``partial_x`` by sampling the free region of ``shape``."""
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    # NOTE(review): these shadow nn.Module.train/eval with a narrower signature
    # (no ``mode`` argument) and delegate to the inner backbone only.
    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Apply a wrapper (e.g. DataParallel) to the inner denoiser.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build a beta (noise-variance) schedule of length ``time_num``.

    Supported schedules:
      * 'linear'     — betas linearly spaced from b_start to b_end.
      * 'warm<frac>' — constant b_end, with the first ``int(time_num*frac)``
        steps linearly warmed up from b_start to b_end. Generalizes the old
        hard-coded 'warm0.1'/'warm0.2'/'warm0.5' variants (same output for
        those inputs) to any fraction.

    Raises NotImplementedError for unknown schedule names.
    """
    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num)
    if schedule_type.startswith('warm'):
        try:
            warm_frac = float(schedule_type[len('warm'):])
        except ValueError:
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * warm_frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
        return betas
    raise NotImplementedError(schedule_type)
#############################################################################
def get_dataset(data_root, npoints, category):
    """Build the PartNet 'test' split used for multimodal completion."""
    return GANdatasetPartNet('test', data_root, category, npoints)
def generate_multimodal(opt, netE, save_dir, logger):
    """For every partial shape in the test split, draw 5 diffusion completions
    and export Mitsuba XML renders plus PLY files for partial/completed pairs.
    """
    test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['real']
        x_all = data['raw']
        # j indexes the 5 completion "modes" sampled for the same partial input.
        for j in range(5):
            # Network input: partial points followed by a ground-truth-sized tail;
            # only the first svpoints are conditioned on, the tail only fixes shape.
            x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1)
            recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                     clip_denoised=False).detach().cpu()
            for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))):
                # Pad the partial cloud to full size by repeating its first point
                # svpoints times — presumably so renders have equal point counts.
                partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1)
                rec = d[1]
                rid = d[2]
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j),
                                   (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair')
                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j),
                                   (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy())
                raw_id = rid.split('.')[0]
                save_sample_dir = os.path.join(save_dir, "{}".format(raw_id))
                Path(save_sample_dir).mkdir(parents=True, exist_ok=True)
                # save input partial shape
                if j == 0:
                    save_path = os.path.join(save_sample_dir, "raw.ply")
                    write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path)
                # save completed shape
                save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j))
                write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path)
def main(opt):
    """Run multimodal completion for every '*.pth' checkpoint in opt.ckpt_dir,
    newest epoch first, writing outputs under a per-experiment directory."""
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)
    outf_syn, = setup_output_subdirs(output_dir, 'syn')
    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        # DataParallel is applied to the inner denoiser only (via multi_gpu_wrapper).
        return nn.parallel.DataParallel(m)

    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)
    netE.eval()

    def _ckpt_epoch(path):
        # Fix: the old key used ``x.strip('.pth')`` — str.strip removes a
        # character *set* ('.', 'p', 't', 'h'), which can eat stem characters;
        # parse the epoch number from the extension-less basename instead.
        return int(os.path.splitext(os.path.basename(path))[0].split('_')[-1])

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
    with torch.no_grad():
        # Descending epoch order, same as reversed(sorted(...)) before.
        for ckpt in sorted(ckpts, key=_ckpt_epoch, reverse=True):
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)
            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])
            if opt.generate_multimodal:
                generate_multimodal(opt, netE, outf_syn, logger)
def parse_args():
    """Parse PartNet multimodal-completion options; sets ``opt.cuda``.

    Fix: numeric options previously had no ``type=``, so CLI-supplied values
    arrived as ``str`` (defaults only applied when the flag was omitted).
    Explicit types are added; defaults and option names are unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet',
                        help='input batch size')
    parser.add_argument('--classes', default='Chair')
    parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
    # NOTE(review): plain-default switches; any non-empty CLI string
    # (including "False") is truthy. Kept for interface compatibility.
    parser.add_argument('--eval_recon_mvr', default=False)
    parser.add_argument('--generate_multimodal', default=True)
    parser.add_argument('--eval_saved', default=False)
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)
    parser.add_argument('--svpoints', type=int, default=1024)
    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)
    # params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')
    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)
    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)")
    '''eval'''
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
    opt = parser.parse_args()
    # Mirror CUDA availability onto the namespace used throughout the script.
    opt.cuda = torch.cuda.is_available()
    return opt
if __name__ == '__main__':
    # Script entry point: parse CLI options, then run the completion driver.
    main(parse_args())

View file

@ -1,599 +0,0 @@
from pprint import pprint
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.partnet import GANdatasetPartNet
import trimesh
import csv
import numpy as np
import random
from plyfile import PlyData, PlyElement
def write_ply(points, filename, text=False):
    """Write an (N, 3) point array to *filename* as a PLY 'vertex' element."""
    xyz_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    rows = [tuple(points[i, :3]) for i in range(points.shape[0])]
    vertex = np.array(rows, dtype=xyz_dtype)
    element = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    with open(filename, mode='wb') as handle:
        PlyData([element], text=text).write(handle)
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, e^logvar1) || N(mean2, e^logvar2) )."""
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (var_ratio + scaled_sq_diff - 1.0 + logvar2 - logvar1)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-likelihood of x under a Gaussian discretized into 1-wide bins.

    Assumes data is integers in [0, 1]; the two outer branches handle the
    open-ended leftmost (x < 0.001) and rightmost (x > 0.999) bins.
    """
    assert x.shape == means.shape == log_scales.shape
    standard_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
    centered = x - means
    inv_stdv = torch.exp(-log_scales)
    cdf_upper = standard_normal.cdf(inv_stdv * (centered + 0.5))
    cdf_lower = standard_normal.cdf(inv_stdv * (centered - 0.5))
    tiny = 1e-12
    # clamp(min=tiny) == torch.max(·, ones*tiny): guards log() against zeros.
    log_cdf_upper = torch.log(cdf_upper.clamp(min=tiny))
    log_one_minus_cdf_lower = torch.log((1. - cdf_lower).clamp(min=tiny))
    log_delta = torch.log((cdf_upper - cdf_lower).clamp(min=tiny))
    log_probs = torch.where(
        x < 0.001, log_cdf_upper,
        torch.where(x > 0.999, log_one_minus_cdf_lower, log_delta))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.loss_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
    """
    Sample from the model

    Draws x_{t-1} ~ p(x_{t-1} | x_t) for the completion slice and re-attaches
    the conditioning partial points (first self.sv_points columns) unchanged.
    """
    model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                          return_pred_xstart=True)
    noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
    # no noise when t == 0
    nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
    sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
    # Keep the observed partial shape fixed; only the generated part is resampled.
    sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
    return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
                  noise_fn=torch.randn, clip_denoised=True, keep_running=False):
    """
    Generate samples
    keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

    partial_x: (B, D, sv_points) observed partial cloud used as the condition.
    shape: shape of the generated (free) part only; the returned tensor is
    [partial_x, generated] concatenated along the last axis.
    """
    assert isinstance(shape, (tuple, list))
    # Start the free part from pure noise; the partial part stays fixed throughout.
    img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
    for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
        t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
        img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                              clip_denoised=clip_denoised, return_pred_xstart=False)

    assert img_t[:, :, self.sv_points:].shape == shape
    return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                             noise_fn=torch.randn, clip_denoised=True, keep_running=False):
    """
    Generate samples, returning intermediate images
    Useful for visualizing how denoised images evolve over time

    Args:
        freq (int): an intermediate sample is recorded every `freq` steps
            (plus the very first reverse step, t == total_steps - 1).
        keep_running: True to run len(betas) steps, False for num_timesteps.

    NOTE(review): unlike p_sample_loop there is no partial_x argument, yet
    p_sample still pins the first sv_points columns of the (pure-noise) input
    as if they were the observed partial shape — confirm this variant is
    intended for unconditional sampling.
    """
    assert isinstance(shape, (tuple, list))
    total_steps = self.num_timesteps if not keep_running else len(self.betas)

    img_t = noise_fn(size=shape, dtype=torch.float, device=device)
    imgs = [img_t]
    for t in reversed(range(0, total_steps)):
        t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
        img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                              clip_denoised=clip_denoised,
                              return_pred_xstart=False)
        if t % freq == 0 or t == total_steps - 1:
            imgs.append(img_t)

    assert imgs[-1].shape == shape
    return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
    """
    One variational-bound term: KL(q(x_{t-1}|x_t, x_0) || p(x_{t-1}|x_t)),
    in bits per dimension, computed on the free (non-partial) columns only.
    """
    true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
    model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
        denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
    kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
    # Per-example mean over all non-batch dims, converted from nats to bits.
    kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)

    return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
    """
    Training loss calculation

    data_start: (B, D, N) clean cloud; the first self.sv_points columns are
    the partial condition and are never noised. t: (B,) int64 timesteps.
    noise: optional noise for the free part (sampled if None).
    Returns a (B,) per-example loss.
    """
    B, D, N = data_start.shape
    assert t.shape == torch.Size([B])

    if noise is None:
        noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)

    # Diffuse only the free part; the condition is concatenated back below.
    data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)

    if self.loss_type == 'mse':
        # predict the noise instead of x_start. seems to be weighted naturally like SNR
        eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
        losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
    elif self.loss_type == 'kl':
        losses = self._vb_terms_bpd(
            denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
            return_pred_xstart=False)
    else:
        raise NotImplementedError(self.loss_type)

    assert losses.shape == torch.Size([B])
    return losses
'''debug'''
def _prior_bpd(self, x_start):
    """
    Prior term of the VLB: KL(q(x_T | x_0) || N(0, I)) in bits per dimension.
    Should be near zero when the schedule fully destroys the signal by step T.
    """
    with torch.no_grad():
        B, T = x_start.shape[0], self.num_timesteps
        t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
                             mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
        assert kl_prior.shape == x_start.shape
        # Per-example mean, nats -> bits.
        return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
    """
    Evaluate the full variational bound over all timesteps.

    Returns (mean total bpd, mean per-step VLB terms, mean prior bpd,
    mean per-step x0-prediction MSE), all reduced to scalars.
    """
    with torch.no_grad():
        B, T = x_start.shape[0], self.num_timesteps

        vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
        for t in reversed(range(T)):
            t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
            # Calculate VLB term at the current timestep
            data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
            new_vals_b, pred_xstart = self._vb_terms_bpd(
                denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                clip_denoised=clip_denoised, return_pred_xstart=True)
            # MSE for progressive prediction loss
            assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
            new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
            assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
            # Insert the calculated term into the tensor of all terms
            mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
            vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
            mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
            # Fix: original asserted vals_bt_.shape twice and never checked mse_bt_.
            assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])

        prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points:])
        total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
        assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
            total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
        return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """
    Concrete PVCNN2 denoiser configuration.

    sa_blocks / fp_blocks are the encoder/decoder stage configurations consumed
    by PVCNN2Base; presumably (conv spec, set-abstraction spec) pairs and their
    feature-propagation mirrors — confirm against PVCNN2Base's docstring.
    """
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through constructor; all behavior lives in PVCNN2Base.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Completion model: a GaussianDiffusion schedule driving a PVCNN2 denoiser."""

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)

        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # Prior term of the variational bound, in bits/dim.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        # Full variational-bound evaluation over all timesteps.
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt':mse_bt
        }

    def _denoise(self, data, t):
        # data: (B, D, N) float cloud; t: (B,) int64 timesteps.
        B, D,N= data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        # One training step's loss: sample a timestep per example, diffuse, denoise.
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # Reuse the provided noises only at t == 0; redraw for all other steps.
            noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        # Run the reverse chain conditioned on the partial shape.
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def train(self):
        # NOTE(review): shadows nn.Module.train(mode=True) with a different
        # signature and only toggles the inner denoiser — confirm intended.
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Wrap only the denoiser (e.g. DataParallel); the diffusion part is stateless math.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """
    Build the per-timestep beta schedule as a (time_num,) float64 numpy array.

    schedule_type: 'linear', or 'warm<frac>' (e.g. 'warm0.1', 'warm0.3') which
    holds betas at b_end and linearly warms up the first frac * time_num steps.
    Generalizes the original hard-coded warm0.1/0.2/0.5 cases; those inputs
    produce byte-identical schedules.
    Raises NotImplementedError for unknown schedule names.
    """
    if schedule_type == 'linear':
        # dtype made explicit for consistency with the warm branches.
        betas = np.linspace(b_start, b_end, time_num, dtype=np.float64)
    elif schedule_type.startswith('warm'):
        try:
            frac = float(schedule_type[len('warm'):])
        except ValueError:
            # Preserve the original error type for malformed names.
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    else:
        raise NotImplementedError(schedule_type)
    return betas
#############################################################################
def get_dataset(data_root, npoints, category):
    """Load the PartNet dataset used for completion evaluation."""
    # NOTE(review): despite the local name `train_ds`, this loads the 'test'
    # split — it is only used for evaluation in this script; confirm intended.
    train_ds = GANdatasetPartNet('test', data_root, category, npoints)
    return train_ds
def generate_multimodal(opt, netE, save_dir, logger):
    """
    For each partial input, draw 5 diffusion completions and export them as
    Mitsuba XML and PLY files under save_dir.

    netE: the diffusion Model (on GPU). logger is accepted but unused here.
    """
    test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['real']
        x_all = data['raw']

        # 5 independent samples per input -> multimodal completions.
        for j in range(5):
            x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1)

            # Condition on the first svpoints columns; sample the rest.
            recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                     clip_denoised=False).detach().cpu()

            for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))):
                # Pad the partial cloud by repeating its first point so both
                # clouds have the same point count for stacked export.
                partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1)
                rec = d[1]
                rid = d[2]
                # NOTE(review): cat='chair' is hard-coded even though
                # opt.classes may differ — confirm whether this should follow opt.classes.
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j),
                                   (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair')

                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j),
                                   (torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy())

                raw_id = rid.split('.')[0]
                save_sample_dir = os.path.join(save_dir, "{}".format(raw_id))
                Path(save_sample_dir).mkdir(parents=True, exist_ok=True)

                # save input partial shape
                if j == 0:
                    save_path = os.path.join(save_sample_dir, "raw.ply")
                    write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path)

                # save completed shape
                save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j))
                write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path)
def main(opt):
    """
    Set up output directories and the model, then run multimodal completion
    for every checkpoint found in opt.ckpt_dir (newest epoch first).
    """
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    outf_syn, = setup_output_subdirs(output_dir, 'syn')

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    # Fix: the original called netE.cuda() unconditionally a second time,
    # which crashes on CPU-only hosts; move to the guarded branch only.
    if opt.cuda:
        netE = netE.cuda()

    def _transform_(m):
        return nn.parallel.DataParallel(m)

    netE.multi_gpu_wrapper(_transform_)

    netE.eval()

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    def _epoch_of(path):
        # Fix: the original used x.strip('.pth'), which strips the *characters*
        # '.', 'p', 't', 'h' from both ends (not the suffix) and can corrupt
        # names; strip the extension properly instead.
        return int(os.path.splitext(os.path.basename(path))[0].split('_')[-1])

    with torch.no_grad():
        for ckpt in sorted(ckpts, key=_epoch_of, reverse=True):

            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)

            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])

            if opt.generate_multimodal:
                generate_multimodal(opt, netE, outf_syn, logger)
def parse_args(argv=None):
    """
    Build and parse the CLI options for shape-completion evaluation.

    argv: optional list of argument strings; defaults to sys.argv (via
    argparse), which keeps the function testable and backward-compatible.
    Numeric options now carry explicit type= so CLI overrides parse to
    numbers instead of strings (defaults are unchanged).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet',
                        help='input batch size')
    parser.add_argument('--classes', default='Table')

    parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    # NOTE(review): these boolean options cannot be toggled from the CLI (any
    # string value is truthy); left as-is for interface compatibility.
    parser.add_argument('--eval_recon_mvr', default=False)
    parser.add_argument('--generate_multimodal', default=True)
    parser.add_argument('--eval_saved', default=False)

    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)
    parser.add_argument('--svpoints', type=int, default=1024)
    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    #params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)

    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)")

    '''eval'''
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')

    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')

    opt = parser.parse_args(argv)

    # Resolve device availability once so downstream code can just check opt.cuda.
    if torch.cuda.is_available():
        opt.cuda = True
    else:
        opt.cuda = False

    return opt
# Script entry point: parse CLI options, then evaluate checkpoints.
if __name__ == '__main__':
    opt = parse_args()
    main(opt)

View file

@ -1,681 +0,0 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL divergence KL(N(mean1, e^logvar1) || N(mean2, e^logvar2)).
    Inputs broadcast like ordinary tensors.
    """
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * ((logvar2 - logvar1) + var_ratio + scaled_sq_diff - 1.0)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of x under a Gaussian discretized into unit-width buckets.

    Assumes data is integers [0, 1]; the extreme buckets absorb the tails
    (x < 0.001 uses the left CDF, x > 0.999 the right survival function).
    """
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    diff = x - means
    precision = torch.exp(-log_scales)
    cdf_hi = std_normal.cdf(precision * (diff + 0.5))
    cdf_lo = std_normal.cdf(precision * (diff - .5))

    # Clamp before log to avoid -inf when the bucket probability underflows.
    floor_hi = torch.ones_like(cdf_hi) * 1e-12
    log_cdf_hi = torch.log(torch.max(cdf_hi, floor_hi))
    log_sf_lo = torch.log(torch.max(1. - cdf_lo, torch.ones_like(cdf_lo) * 1e-12))
    log_bucket = torch.log(torch.max(cdf_hi - cdf_lo, torch.ones_like(cdf_hi) * 1e-12))

    log_probs = torch.where(
        x < 0.001, log_cdf_hi,
        torch.where(x > 0.999, log_sf_lo, log_bucket))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """
    DDPM-style Gaussian diffusion for conditional point-cloud completion.

    Point clouds are (B, D, N) tensors whose first ``sv_points`` columns along
    the last axis are the observed partial shape; only the remaining columns
    are diffused and denoised.

    betas: (T,) numpy array of schedule variances in (0, 1].
    loss_type: 'mse' (noise regression) or 'kl' (variational bound).
    model_mean_type: only 'eps' (network predicts noise) is implemented.
    model_var_type: 'fixedsmall' or 'fixedlarge' reverse-process variance.

    Fixes vs. original: p_mean_variance raised NotImplementedError with the
    wrong attribute (loss_type instead of model_mean_type), and calc_bpd_loop's
    shape assert checked vals_bt_ twice instead of also checking mse_bt_.
    """
    def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.sv_points = sv_points

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """Mean/variance/log-variance of the marginal q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """
        Mean and (log-)variance of p(x_{t-1} | x_t) on the free columns only.
        clip_denoised is accepted for interface compatibility but not applied.
        """
        model_output = denoise_fn(data, t)[:, :, self.sv_points:]

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points:], t=t, eps=model_output)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points:], t=t)
        else:
            # Fix: report the attribute actually dispatched on (was self.loss_type).
            raise NotImplementedError(self.model_mean_type)

        assert model_mean.shape == x_recon.shape
        assert model_variance.shape == model_log_variance.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Invert the forward process: recover x_0 from x_t and predicted noise."""
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """
        Sample from the model

        Draws x_{t-1} for the free slice and re-attaches the conditioning
        partial points (first self.sv_points columns) unchanged.
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                              return_pred_xstart=True)
        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)

        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, partial_x, denoise_fn, shape, device,
                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

        shape is the shape of the generated (free) part only; the return value
        is [partial_x, generated] concatenated along the last axis.
        """
        assert isinstance(shape, (tuple, list))
        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False)

        assert img_t[:, :, self.sv_points:].shape == shape
        return img_t

    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                                 noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time

        Args:
            freq (int): an intermediate sample is recorded every `freq` steps
                (plus the very first reverse step).
            keep_running: True to run len(betas) steps, False for num_timesteps.
        """
        assert isinstance(shape, (tuple, list))
        total_steps = self.num_timesteps if not keep_running else len(self.betas)

        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0, total_steps)):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised,
                                  return_pred_xstart=False)
            if t % freq == 0 or t == total_steps - 1:
                imgs.append(img_t)

        assert imgs[-1].shape == shape
        return imgs

    '''losses'''

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        """
        One VLB term: KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) in bits/dim,
        computed on the free (non-partial) columns only.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:, :, self.sv_points:], x_t=data_t[:, :, self.sv_points:], t=t)
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        # Per-example mean over non-batch dims, converted from nats to bits.
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)

        return (kl, pred_xstart) if return_pred_xstart else kl

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """
        Training loss calculation

        data_start: (B, D, N) clean cloud; the first self.sv_points columns
        are the partial condition and are never noised. Returns a (B,) loss.
        """
        B, D, N = data_start.shape
        assert t.shape == torch.Size([B])

        if noise is None:
            noise = torch.randn(data_start[:, :, self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)

        data_t = self.q_sample(x_start=data_start[:, :, self.sv_points:], t=t, noise=noise)

        if self.loss_type == 'mse':
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(torch.cat([data_start[:, :, :self.sv_points], data_t], dim=-1), t)[:, :, self.sv_points:]
            losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == 'kl':
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
                return_pred_xstart=False)
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == torch.Size([B])
        return losses

    '''debug'''

    def _prior_bpd(self, x_start):
        """Prior VLB term: KL(q(x_T | x_0) || N(0, I)) in bits per dimension."""
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
                                 mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        """
        Evaluate the full variational bound over all timesteps.

        Returns (mean total bpd, mean per-step VLB terms, mean prior bpd,
        mean per-step x0-prediction MSE), all reduced to scalars.
        """
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps

            vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):

                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                    clip_denoised=clip_denoised, return_pred_xstart=True)
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                # Fix: original asserted vals_bt_.shape twice, never mse_bt_.
                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])

            prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points:])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
                total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """
    PVCNN2 denoiser with the concrete stage configuration used for completion.

    sa_blocks / fp_blocks parameterize PVCNN2Base's encoder and decoder stages;
    presumably ((conv spec) | None, (set-abstraction spec)) pairs and their
    feature-propagation mirrors — confirm against PVCNN2Base.
    """
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through; all behavior lives in PVCNN2Base.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Couples a GaussianDiffusion schedule with a PVCNN2 denoiser for completion."""

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super().__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type,
                                           model_var_type, args.svpoints)
        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints,
                            embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        """Prior term of the variational bound for x0."""
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        """Evaluate the full variational bound; returns a dict of summary scalars."""
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(
            self._denoise, x0, clip_denoised)
        return {'total_bpd_b': total_bpd_b,
                'terms_bpd': vals_bt,
                'prior_bpd_b': prior_bpd_b,
                'mse_bt': mse_bt}

    def _denoise(self, data, t):
        """Run the PVCNN2 network on a (B, D, N) float cloud at timesteps t."""
        batch = data.shape[0]
        assert data.dtype == torch.float
        assert t.shape == torch.Size([batch]) and t.dtype == torch.int64
        return self.model(data, t)

    def get_loss_iter(self, data, noises=None):
        """Sample one timestep per example and return the (B,) training loss."""
        batch = data.shape[0]
        t = torch.randint(0, self.diffusion.num_timesteps, size=(batch,), device=data.device)

        if noises is not None:
            # Only reuse the provided noises at t == 0; redraw for other steps.
            redraw = t != 0
            noises[redraw] = torch.randn(int(redraw.sum()), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data,
                                         t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([batch])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        """Run the reverse diffusion chain conditioned on partial_x."""
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape,
                                            device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """
    Build the per-timestep beta schedule as a (time_num,) float64 numpy array.

    'linear' interpolates b_start..b_end; 'warm0.1'/'warm0.2'/'warm0.5' hold
    betas at b_end with a linear warm-up over the first fraction of steps.
    Raises NotImplementedError for any other schedule name.
    """
    warm_fractions = {'warm0.1': 0.1, 'warm0.2': 0.2, 'warm0.5': 0.5}

    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num)

    if schedule_type in warm_fractions:
        betas = np.full(time_num, b_end, dtype=np.float64)
        warmup = int(time_num * warm_fractions[schedule_type])
        betas[:warmup] = np.linspace(b_start, b_end, warmup, dtype=np.float64)
        return betas

    raise NotImplementedError(schedule_type)
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
    """
    Build the single-view-reconstruction validation set.

    The training split is loaded only to obtain its normalization statistics
    (all_points_mean / all_points_std), which are applied to the val set so
    both use the same coordinate frame.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
                                        categories=category, split='train',
                                        tr_sample_size=npoints,
                                        te_sample_size=npoints,
                                        scale=1.,
                                        normalize_per_shape=False,
                                        normalize_std_per_axis=False,
                                        random_subsample=True)
    te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
                                            cache=os.path.join(mesh_root, '../cache'), split='val',
                                            categories=category,
                                            radius=3, elev=0, azim=0, img_size=512, focal_length=1000,
                                            npoints=npoints, sv_samples=200,
                                            all_points_mean=tr_dataset.all_points_mean,
                                            all_points_std=tr_dataset.all_points_std,
                                            )
    return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
    """
    Build the multi-view-reconstruction validation set.

    As in get_svr_dataset, the training split supplies the normalization
    statistics shared with the val split; only the val dataset is returned.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
                                        categories=category, split='train',
                                        tr_sample_size=npoints,
                                        te_sample_size=npoints,
                                        scale=1.,
                                        normalize_per_shape=False,
                                        normalize_std_per_axis=False,
                                        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
                                           cache=os.path.join(pc_dataroot, '../cache'), split='val',
                                           categories=category,
                                           npoints=npoints, sv_samples=200,
                                           all_points_mean=tr_dataset.all_points_mean,
                                           all_points_std=tr_dataset.all_points_std,
                                           )
    return te_dataset
def evaluate_recon_mvr(opt, netE, save_dir, logger):
    """Complete every view of the multi-view test set and report EMD/CD.

    Saves the denormalized ground truth ('recon_gt.pth'), the masked partial
    inputs ('recon_masked.pth') and the per-view metrics plus sampled clouds
    ('ours_results.pth') into ``save_dir``.
    (Fix: removed an unused counter local ``k`` that was never read.)
    """
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    ref = []
    samples = []
    images = []
    masked = []
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        img_all = data['image']
        B, V, N, C = x_all.shape
        gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1)
        # Fold the view dimension into the batch and move channels first.
        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        img = img_all.reshape(B * V, *img_all.shape[2:])
        m, s = data['mean'].float(), data['std'].float()
        # Condition on the first svpoints observed points; sample the rest.
        recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                 clip_denoised=False).detach().cpu()
        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()
        # Undo the dataset normalization before storing/evaluating.
        x_adj = x.reshape(B, V, N, C) * s + m
        recon_adj = recon.reshape(B, V, N, C) * s + m
        img = img.reshape(B, V, *img.shape[1:])
        ref.append(gt_all * s + m)
        masked.append(x_adj[:, :, :test_dataloader.dataset.sv_samples, :])
        samples.append(recon_adj)
        images.append(img)
    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)
    B, V, N, C = ref_pcs.shape
    torch.save(ref_pcs.reshape(B, V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B * V, N, C),
                     ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)
    # Per-view metrics come back flattened as (B*V,); restore (B, V).
    results = {ky: val.reshape(B, V) if val.shape == torch.Size([B * V, ]) else val for ky, val in results.items()}
    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})
    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
def generate_multimodal(opt, netE, save_dir, logger):
    """Sample five completion modes per single-view input and dump outputs.

    For each batch, five diffusion samples ("modes") are drawn conditioned on
    the first ``opt.svpoints`` partial points; depth images, Mitsuba XML
    scenes and PLY point clouds are written under ``save_dir``.
    """
    test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        img_all = data['image']
        # Fix: these directories were indexed with an undefined name `k`
        # (NameError at runtime); the batch index `i` is what was intended.
        Path(os.path.join(save_dir, 'x_%03d' % i)).mkdir(exist_ok=True, parents=True)
        Path(os.path.join(save_dir, 'x_ply_%03d' % i)).mkdir(exist_ok=True, parents=True)
        for v in range(5):
            x = x_all.transpose(1, 2).contiguous()
            img = img_all
            m, s = data['mean'].float(), data['std'].float()
            # Condition on the observed prefix; sample the remaining points.
            recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                     clip_denoised=False).detach().cpu()
            recon = recon.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()
            for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
                # d = (gt, recon, partial, depth_image); d[-1] is the image.
                im = np.fliplr(np.flipud(d[-1]))
                plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png' % (i, p)), im, cmap='gray')
                # NOTE(review): the writers below target 'x_%03d_%03d' while the
                # mkdir above creates 'x_%03d' — presumably the writers create
                # their own directories; verify against their implementations.
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
def main(opt):
    """Load each checkpoint in ``opt.ckpt_dir`` (newest epoch first) and run
    the enabled evaluation passes on it.

    Outputs go to a fresh experiment directory derived from this file's name.
    """
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)
    outf_syn, = setup_output_subdirs(output_dir, 'syn')
    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        return nn.parallel.DataParallel(m)
    # NOTE(review): .cuda() below runs unconditionally (and the samplers also
    # call .cuda() on inputs), so a CUDA device is effectively required.
    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)
    netE.eval()
    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    def _epoch_of(path):
        # Fix: str.strip('.pth') strips a *character set* from both ends (it
        # can eat leading/trailing letters of the stem), not the '.pth'
        # suffix; splitext removes exactly the extension.
        return int(os.path.splitext(os.path.basename(path))[0].split('_')[-1])
    with torch.no_grad():
        for ckpt in sorted(ckpts, key=_epoch_of, reverse=True):
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)
            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])
            if opt.eval_recon_mvr:
                # Evaluate completion on the multi-view test set.
                evaluate_recon_mvr(opt, netE, outf_syn, logger)
            if opt.generate_multimodal:
                generate_multimodal(opt, netE, outf_syn, logger)
def parse_args(argv=None):
    """Build the evaluation option namespace.

    Parameters
    ----------
    argv : list[str] | None
        Argument list to parse; ``None`` (the default) falls back to
        ``sys.argv[1:]``, preserving the original no-argument call.

    Fixes: numeric options previously lacked ``type=`` and would arrive as
    strings when passed on the command line; the dataroot help strings were
    copy-pasted from ``--batch_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
                        help='path to the ShapeNet point-cloud root')
    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                        help='path to the preprocessed single/multi-view data root')
    parser.add_argument('--classes', default=['airplane'])
    parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
    parser.add_argument('--eval_recon_mvr', default=True)
    parser.add_argument('--generate_multimodal', default=False)
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)
    parser.add_argument('--svpoints', type=int, default=200)
    # --- model ---
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)
    # params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')
    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)
    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/airplane_ckpt/', help="path to netE (to continue training)")
    # --- eval ---
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
    opt = parser.parse_args(argv)
    # Mirror CUDA availability onto the namespace so downstream code can
    # branch on it.
    opt.cuda = torch.cuda.is_available()
    return opt
if __name__ == '__main__':
    # Parse CLI options, seed the RNGs for reproducibility, then run the
    # checkpoint evaluation loop.
    opt = parse_args()
    set_seed(opt)
    main(opt)

View file

@ -1,764 +0,0 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ).

    Arguments are tensors (or broadcastable scalars) of means and
    log-variances; the result has the broadcast shape of the inputs.
    """
    log_ratio = logvar2 - logvar1
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (log_ratio + var_ratio + scaled_sq_diff - 1.0)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-probability of x under a Gaussian evaluated on width-1 bins.

    Each x contributes the probability mass between (x - 0.5) and (x + 0.5)
    under N(means, exp(log_scales)^2); values below 0.001 use the half-open
    lower bin and values above 0.999 the half-open upper bin. All three
    arguments must share one shape.
    """
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
    centered = x - means
    inv_std = torch.exp(-log_scales)
    cdf_hi = std_normal.cdf(inv_std * (centered + 0.5))
    cdf_lo = std_normal.cdf(inv_std * (centered - .5))
    # Clamp before log so empty bins do not produce -inf.
    log_cdf_hi = torch.log(torch.max(cdf_hi, torch.ones_like(cdf_hi) * 1e-12))
    log_sf_lo = torch.log(torch.max(1. - cdf_lo, torch.ones_like(cdf_lo) * 1e-12))
    log_bin = torch.log(torch.max(cdf_hi - cdf_lo, torch.ones_like(cdf_hi) * 1e-12))
    log_probs = torch.where(
        x < 0.001, log_cdf_hi,
        torch.where(x > 0.999, log_sf_lo, log_bin))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """DDPM forward/reverse process specialized for point-cloud completion.

    Every (B, C, N) tensor carries the observed partial shape in its first
    ``sv_points`` positions along the last axis; those points are never
    noised, and only the remaining points are diffused/denoised.

    Schedule coefficients are computed from ``betas`` in float64 for accuracy
    and stored as float32 CPU tensors, moved to the data's device on use.

    Fix vs. original: the unsupported ``model_mean_type`` branch of
    ``p_mean_variance`` raised ``NotImplementedError(self.loss_type)``,
    reporting the wrong attribute; it now reports ``model_mean_type``.
    """
    def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.sv_points = sv_points

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """Mean/variance/log-variance of the marginal q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Mean/variance of p(x_{t-1} | x_t) for the free (non-observed) points.

        ``data`` is the full (B, C, N) tensor including the sv_points prefix;
        the network sees all of it but only the suffix is predicted.
        NOTE(review): clip_denoised is accepted but never applied here.
        """
        model_output = denoise_fn(data, t)[:,:,self.sv_points:]

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
        else:
            # Fix: this branch dispatches on model_mean_type, so report that
            # attribute (the original raised with self.loss_type).
            raise NotImplementedError(self.model_mean_type)

        assert model_mean.shape == x_recon.shape
        assert model_variance.shape == model_log_variance.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        # Invert the forward marginal: x_0 = (x_t - sqrt(1-acp)*eps) / sqrt(acp).
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """
        Sample from the model
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                             return_pred_xstart=True)
        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)

        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))

        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        # Re-attach the (never-noised) conditioning prefix.
        sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, partial_x, denoise_fn, shape, device,
                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
        """
        assert isinstance(shape, (tuple, list))
        # Start from pure noise for the free points, appended after the
        # observed partial points.
        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False)

        assert img_t[:,:,self.sv_points:].shape == shape
        return img_t

    def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
                                 noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time
        NOTE(review): unlike p_sample_loop, no partial_x prefix is supplied
        here, so the first sv_points of the noise act as the condition inside
        p_sample — confirm intent before using this path.
        """
        assert isinstance(shape, (tuple, list))

        total_steps = self.num_timesteps if not keep_running else len(self.betas)

        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0, total_steps)):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised,
                                  return_pred_xstart=False)
            if t % freq == 0 or t == total_steps-1:
                imgs.append(img_t)

        assert imgs[-1].shape == shape
        return imgs

    '''losses'''

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        # KL( q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t) ) over the free points,
        # averaged per element and converted from nats to bits.
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
        return (kl, pred_xstart) if return_pred_xstart else kl

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """
        Training loss calculation
        """
        B, D, N = data_start.shape
        assert t.shape == torch.Size([B])

        if noise is None:
            # Noise only the free points; the sv_points prefix stays clean.
            noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)

        data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)

        if self.loss_type == 'mse':
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
            losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == 'kl':
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
                return_pred_xstart=False)
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == torch.Size([B])
        return losses

    '''debug'''

    def _prior_bpd(self, x_start):
        # KL between q(x_T | x_0) and the standard-normal prior, in bits/dim.
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
                                 mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        """Evaluate the full variational bound timestep-by-timestep.

        Returns (mean total bpd, mean per-term bpd, mean prior bpd, mean MSE).
        """
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps

            vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):

                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
                    clip_denoised=clip_denoised, return_pred_xstart=True)
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])

            prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
                   total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
    """Point-Voxel CNN backbone configuration used for shape completion.

    Only the stage configurations differ from PVCNN2Base; all behavior lives
    in the base class.
    # NOTE(review): the tuple layout of sa_blocks/fp_blocks is defined by
    # PVCNN2Base (not visible here) — confirm field meanings there.
    """
    # Set-abstraction (encoder) stage configs consumed by PVCNN2Base.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    # Feature-propagation (decoder) stage configs consumed by PVCNN2Base.
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through constructor; every argument is forwarded verbatim.
        super().__init__(
            num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Completion model: a PVCNN2 backbone driven by GaussianDiffusion.

    The first ``args.svpoints`` points of every (B, C, N) cloud are the
    observed conditioning input; only the remaining points are diffused.
    """
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
        super(Model, self).__init__()
        # Diffusion schedule/process helper (holds no learnable parameters).
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
        # Denoising backbone; extra_feature_channels=0 means xyz-only input.
        self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # KL between q(x_T | x_0) and the standard-normal prior, in bits/dim.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        """Full variational-bound breakdown (total / per-term / prior / MSE)."""
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt':mse_bt
        }

    def _denoise(self, data, t):
        # data: (B, D, N) float tensor; t: (B,) int64 diffusion timesteps.
        B, D,N= data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)
        return out

    def get_loss_iter(self, data, noises=None):
        """Draw one random timestep per example and return per-example loss."""
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # Re-draw noise for every example with t != 0; caller-supplied
            # noise is kept only where t == 0.
            noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        """Complete ``partial_x`` by sampling ``shape`` remaining points."""
        return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def train(self):
        # NOTE(review): overrides nn.Module.train() without the ``mode``
        # argument and only toggles the backbone (the diffusion helper holds
        # no modules) — confirm no caller relies on nn.Module semantics.
        self.model.train()

    def eval(self):
        # See note on train(): only the backbone is switched to eval mode.
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Replace the backbone with a wrapped (e.g. DataParallel) version.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build a beta schedule of length ``time_num``.

    'linear' interpolates from b_start to b_end over all steps; the
    'warm0.X' variants hold b_end everywhere except the first X fraction of
    steps, which ramp linearly from b_start to b_end.

    Raises NotImplementedError for any other schedule_type.
    """
    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num)
    warm_frac = {'warm0.1': 0.1, 'warm0.2': 0.2, 'warm0.5': 0.5}.get(schedule_type)
    if warm_frac is None:
        raise NotImplementedError(schedule_type)
    betas = b_end * np.ones(time_num, dtype=np.float64)
    warmup_time = int(time_num * warm_frac)
    betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    return betas
#############################################################################
def get_pc_dataset(pc_dataroot, mesh_root, npoints,category):
    """Return the normalized ShapeNet training point-cloud split.

    ``mesh_root`` is accepted for signature parity with the other dataset
    helpers but is not used here.
    """
    kwargs = dict(
        root_dir=pc_dataroot,
        categories=category,
        split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
    )
    return ShapeNet15kPointClouds(**kwargs)
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
    """Build the single-view-reconstruction validation set (elev -89, azim 180).

    The training split is constructed only to reuse its normalization
    mean/std so validation points share the training normalization.
    """
    train_kwargs = dict(
        root_dir=pc_dataroot,
        categories=category,
        split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
    )
    train_set = ShapeNet15kPointClouds(**train_kwargs)
    return ShapeNet_Singleview_Points(
        root_pc=pc_dataroot,
        root_mesh=mesh_root,
        cache=os.path.join(mesh_root, '../cache'),
        split='val',
        categories=category,
        radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
        npoints=npoints,
        sv_samples=200,
        all_points_mean=train_set.all_points_mean,
        all_points_std=train_set.all_points_std,
    )
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
    """Build the multi-view-reconstruction validation set.

    The training split is constructed only to reuse its normalization
    mean/std so validation points share the training normalization.
    """
    train_kwargs = dict(
        root_dir=pc_dataroot,
        categories=category,
        split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
    )
    train_set = ShapeNet15kPointClouds(**train_kwargs)
    return ShapeNet_Multiview_Points(
        root_pc=pc_dataroot,
        root_views=views_root,
        cache=os.path.join(pc_dataroot, '../cache'),
        split='val',
        categories=category,
        npoints=npoints,
        sv_samples=200,
        all_points_mean=train_set.all_points_mean,
        all_points_std=train_set.all_points_std,
    )
def evaluate_recon_mvr(opt, netE, save_dir, logger):
    """Complete every view of the multi-view test set and report EMD/CD.

    Saves denormalized ground truth, masked partial inputs and per-view
    metrics plus sampled clouds under ``save_dir``.
    (Fix: removed dead locals — an unused counter ``k`` and an ``images``
    list whose every use was commented out.)
    """
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    ref = []
    samples = []
    masked = []
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        # img_all = data['image']
        B, V, N, C = x_all.shape
        gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1)
        # Fold the view dimension into the batch and move channels first.
        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        # img = img_all.reshape(B * V, *img_all.shape[2:])
        m, s = data['mean'].float(), data['std'].float()
        # Condition on the first svpoints observed points; sample the rest.
        recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                 clip_denoised=False).detach().cpu()
        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()
        # Undo the dataset normalization before storing/evaluating.
        x_adj = x.reshape(B, V, N, C) * s + m
        recon_adj = recon.reshape(B, V, N, C) * s + m
        # img = img.reshape(B,V,*img.shape[1:])
        ref.append(gt_all * s + m)
        masked.append(x_adj[:, :, :test_dataloader.dataset.sv_samples, :])
        samples.append(recon_adj)
    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    masked = torch.cat(masked, dim=0)
    B, V, N, C = ref_pcs.shape
    torch.save(ref_pcs.reshape(B, V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B * V, N, C),
                     ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)
    # Per-view metrics come back flattened as (B*V,); restore (B, V).
    results = {ky: val.reshape(B, V) if val.shape == torch.Size([B * V, ]) else val for ky, val in results.items()}
    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})
    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
    # Free the large tensors before the caller moves to the next checkpoint.
    del ref_pcs, masked, results
def evaluate_saved(opt, netE, save_dir, logger):
    """Recompute full metrics from previously saved reconstruction tensors.

    Loads 'recon_gt.pth' and 'ours_results.pth', compares them view by view
    with compute_all_metrics, and prints per-view and averaged results.
    NOTE(review): netE, save_dir and logger are unused and the results
    directory is hard-coded to one past chair run — parameterize before reuse.
    """
    ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
    gt_pth = ours_base + '/recon_gt.pth'
    ours_pth = ours_base + '/ours_results.pth'
    # Saved as (B, V, N, C); permute so iteration below is per-view.
    gt = torch.load(gt_pth).permute(1,0,2,3)
    ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)
    all_res = {}
    for i, (gt_, ours_) in enumerate(zip(gt, ours)):
        results = compute_all_metrics(gt_, ours_, opt.batch_size)
        # Accumulate each metric across views for averaging afterwards.
        for key, val in results.items():
            if i == 0:
                all_res[key] = val
            else:
                all_res[key] += val
        pprint(results)
    # Average the accumulated metrics over the number of views.
    for key, val in all_res.items():
        all_res[key] = val / gt.shape[0]
    pprint({key: val.mean().item() for key, val in all_res.items()})
def generate_multimodal(opt, netE, save_dir, logger):
    """Sample six completion modes per single-view input and dump outputs.

    For each batch, six diffusion samples ("modes") are drawn conditioned on
    the first ``opt.svpoints`` partial points; depth images, Mitsuba XML
    scenes (cat='chair') and PLY clouds are written under ``save_dir``.
    """
    test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        img_all = data['image']

        for v in range(6):
            x = x_all.transpose(1, 2).contiguous()
            img = img_all
            m, s = data['mean'].float(), data['std'].float()
            # Condition on the observed prefix; sample the remaining points.
            recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
                                    clip_denoised=False).detach().cpu()
            recon = recon.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()
            for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
                # d = (gt, recon, partial, depth_image); d[-1] is the image.
                im = np.fliplr(np.flipud(d[-1]))
                plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
                # Denormalize (gt, recon, partial) before export.
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair')
                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
def redwood_demo(opt, netE, save_dir, logger):
    """Qualitative completion demo on a single Redwood table scan.

    Draws 20 completion modes from one partial scan and writes XML scenes,
    PLY clouds and a tensor dump under ``save_dir``.
    NOTE(review): input/ground-truth paths are hard-coded, and exit() at the
    end terminates the whole process after the first checkpoint.
    """
    import open3d as o3d
    pth = "/viscam/u/alexzhou907/01DATA/redwood/01605_sample_1.ply"
    pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/01605_pc_gt.ply"
    points = np.asarray(o3d.io.read_point_cloud(pth).points)
    gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points)

    # Subsample the ground truth to npoints, save and render for reference.
    np.save('gt.npy', gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)])
    write_to_xml_batch(os.path.join(save_dir, 'xml_gt'),
                       gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)][None], cat='table')

    test_dataset = get_pc_dataset(opt.dataroot_pc, opt.dataroot_sv,
                                  opt.npoints, opt.classes)
    # Normalize the scan with the training set's mean/std.
    m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float()

    x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float()
    x = (x-m)/s
    x = x[None].transpose(1,2).cuda()
    res = []
    # Draw 20 completion modes conditioned on the first svpoints points.
    for k in range(20):
        recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda',
                                 clip_denoised=False).detach().cpu()
        recon = recon.transpose(1, 2).contiguous()
        # Denormalize back to the scan's original coordinate frame.
        recon = recon * s+ m
        res.append(recon)
    res = torch.cat(res, dim=0)
    write_to_xml_batch(os.path.join(save_dir, 'xml'),
                       (res).numpy(), cat='table')

    export_to_pc_batch(os.path.join(save_dir, 'ply'),
                       (res).numpy())
    torch.save(res, os.path.join(save_dir, 'redwood_demo.pth'))
    exit()
def main(opt):
    """Load each checkpoint in ``opt.ckpt_dir`` (newest epoch first) and run
    the enabled evaluation/demo passes on it.

    Outputs go to a fresh experiment directory derived from this file's name.
    """
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)
    outf_syn, = setup_output_subdirs(output_dir, 'syn')
    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        return nn.parallel.DataParallel(m)
    # NOTE(review): .cuda() below runs unconditionally (and the samplers also
    # call .cuda() on inputs), so a CUDA device is effectively required.
    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)
    netE.eval()
    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    def _epoch_of(path):
        # Fix: str.strip('.pth') strips a *character set* from both ends (it
        # can eat leading/trailing letters of the stem), not the '.pth'
        # suffix; splitext removes exactly the extension.
        return int(os.path.splitext(os.path.basename(path))[0].split('_')[-1])
    with torch.no_grad():
        for ckpt in sorted(ckpts, key=_epoch_of, reverse=True):
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)
            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])
            if opt.eval_recon_mvr:
                # Evaluate completion on the multi-view test set.
                evaluate_recon_mvr(opt, netE, outf_syn, logger)
            if opt.generate_multimodal:
                generate_multimodal(opt, netE, outf_syn, logger)
            if opt.eval_saved:
                evaluate_saved(opt, netE, outf_syn, logger)
            if opt.eval_redwood:
                # NOTE: redwood_demo calls exit(), ending the loop here.
                redwood_demo(opt, netE, outf_syn, logger)
def parse_args(args=None):
    """Parse command-line options for the completion evaluation script.

    Args:
        args: optional list of CLI tokens (defaults to ``sys.argv[1:]``);
            exposed so the parser can be driven programmatically/from tests.

    Returns:
        argparse.Namespace, with an extra ``cuda`` bool reflecting whether a
        GPU is available.
    """
    def _str2bool(v):
        # Previously boolean options had no type=, so any non-empty string
        # (including "False") parsed as truthy. Accept common spellings.
        if isinstance(v, bool):
            return v
        return v.lower() in ('yes', 'true', 't', '1')

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point-cloud root')
    parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                        help='path to the preprocessed single-view renderings')
    parser.add_argument('--classes', default=['table'])

    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    # Which evaluations to run (booleans).
    parser.add_argument('--eval_recon_mvr', type=_str2bool, default=False)
    parser.add_argument('--generate_multimodal', type=_str2bool, default=False)
    parser.add_argument('--eval_saved', type=_str2bool, default=False)
    parser.add_argument('--eval_redwood', type=_str2bool, default=True)

    parser.add_argument('--nc', type=int, default=3)          # point channels (xyz)
    parser.add_argument('--npoints', type=int, default=2048)  # points per cloud
    parser.add_argument('--svpoints', type=int, default=200)  # observed (partial) points

    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    # params
    parser.add_argument('--attention', type=_str2bool, default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)

    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/9_res32_pc_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-12-16-14-09-50', help="path to netE (to continue training)")

    '''eval'''
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')

    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')

    opt = parser.parse_args(args)

    if torch.cuda.is_available():
        opt.cuda = True
    else:
        opt.cuda = False

    return opt
if __name__ == '__main__':
    # Parse CLI options, then run checkpoint evaluation.
    opt = parse_args()

    main(opt)

View file

View file

@ -1,905 +0,0 @@
import torch
from pprint import pprint
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_generation import PVCNN2Base
from tqdm import tqdm
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ).

    Inputs are tensors (or broadcastable scalars) giving the means and
    log-variances of two diagonal Gaussians.
    """
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (logvar2 - logvar1 - 1.0 + var_ratio + scaled_sq_diff)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
    """DDPM bookkeeping: precomputes per-timestep coefficients from the beta
    schedule and implements the forward process q, the reverse process p, and
    ancestral sampling (with optional per-step constraint guidance)."""
    def __init__(self,betas, loss_type, model_mean_type, model_var_type):
        # loss_type: training loss id (e.g. 'mse'); model_mean_type: what the
        # network predicts (only 'eps' is implemented below); model_var_type:
        # 'fixedsmall' | 'fixedlarge' choice of reverse-process variance.
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        # Moments of q(x_t | x_0): the mean shrinks x_0 by sqrt(alpha-bar_t)
        # while the variance grows to (1 - alpha-bar_t).
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Reverse-process moments p(x_{t-1} | x_t): run the denoiser, then
        derive mean/variance according to the configured parameterization."""
        model_output = denoise_fn(data, t)

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            # Network predicts the noise; recover x_0 and then the posterior mean.
            x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)

            if clip_denoised:
                x_recon = torch.clamp(x_recon, -.5, .5)

            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
        else:
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape == data.shape
        assert model_variance.shape == model_log_variance.shape == data.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        # Invert q_sample: x_0 = sqrt(1/alpha-bar_t) * x_t - sqrt(1/alpha-bar_t - 1) * eps.
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
        """
        Sample from the model (one reverse step; noise is skipped at t == 0
        or when use_var is False).
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                 return_pred_xstart=True)
        noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
        assert noise.shape == data.shape
        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))

        sample = model_mean
        if use_var:
            sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        assert sample.shape == pred_xstart.shape
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, denoise_fn, shape, device,
                      noise_fn=torch.randn, constrain_fn=lambda x, t:x,
                      clip_denoised=True, max_timestep=None, keep_running=False):
        """
        Generate samples by ancestral sampling from pure noise.
        constrain_fn is applied to the iterate before every reverse step
        (identity by default).
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
        """
        if max_timestep is None:
            final_time = self.num_timesteps
        else:
            final_time = max_timestep

        assert isinstance(shape, (tuple, list))
        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
            img_t = constrain_fn(img_t, t)
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False).detach()

        assert img_t.shape == shape
        return img_t

    def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x):
        # Encode x0 forward to timestep t with q_sample, then denoise back to t = 0.
        assert t >= 1

        t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1)
        encoding = self.q_sample(x0, t_vec)

        img_t = encoding

        for k in reversed(range(0,t)):
            img_t = constrain_fn(img_t, k)
            t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=False, return_pred_xstart=False, use_var=True).detach()

        return img_t
class PVCNN2(PVCNN2Base):
    """PVCNN2Base configured for xyz point clouds.

    NOTE(review): field meanings below are inferred from PVCNN conventions —
    confirm against model/pvcnn_generation.py.
    sa: ((voxel conv channels, #layers, voxel resolution) or None,
         (#centers, radius, #neighbors, MLP widths))
    fp: ((MLP widths), (voxel conv channels, #layers, voxel resolution))
    """
    # Set-abstraction (encoder) stages.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    # Feature-propagation (decoder) stages.
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through to the project base class; all architecture choices
        # live in the class-level block configs above.
        super().__init__(
            num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Couples the PVCNN2 denoiser with GaussianDiffusion and exposes the
    loss / sampling entry points used by the training and eval scripts.

    NOTE(review): several diffusion helpers called here (_prior_bpd,
    calc_bpd_loop, p_losses) are not defined on the GaussianDiffusion class in
    this file — confirm against the full/training version of the class.
    """
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)

        self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # Prior term of the variational bound for x0.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        # Full variational-bound breakdown over all timesteps.
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt':mse_bt
        }


    def _denoise(self, data, t):
        # data: (B, D, N) float point clouds; t: (B,) int64 timesteps.
        B, D,N= data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)

        assert out.shape == torch.Size([B, D, N])
        return out

    def get_loss_iter(self, data, noises=None):
        # Per-sample training losses at uniformly random timesteps.
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # NOTE: mutates the caller's `noises` in place — rows with t != 0
            # are overwritten with fresh Gaussian noise.
            noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x,
                    clip_denoised=False, max_timestep=None,
                    keep_running=False):
        # Ancestral sampling from pure noise (optionally constraint-guided).
        return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            constrain_fn=constrain_fn,
                                            clip_denoised=clip_denoised, max_timestep=max_timestep,
                                            keep_running=keep_running)

    def reconstruct(self, x0, t, constrain_fn=lambda x, t:x):
        # Noise x0 up to timestep t, then denoise back down to t = 0.
        return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)

    def train(self):
        # Only the denoiser holds trainable state.
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # f wraps the inner denoiser (e.g. in nn.DataParallel).
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build the beta (noise-variance) schedule.

    Args:
        schedule_type: 'linear' for a straight linspace from b_start to b_end,
            or 'warmX' (e.g. 'warm0.1') for a schedule that is constant at
            b_end with the first X fraction of steps linearly warmed up from
            b_start to b_end. This generalizes the previous hard-coded
            warm0.1 / warm0.2 / warm0.5 variants to any fraction.
        b_start: beta at step 0.
        b_end: beta at the final step.
        time_num: number of diffusion timesteps.

    Returns:
        np.ndarray of shape (time_num,) in float64.

    Raises:
        NotImplementedError: for an unrecognized schedule_type.
    """
    if schedule_type == 'linear':
        betas = np.linspace(b_start, b_end, time_num)
    elif schedule_type.startswith('warm'):
        try:
            warm_frac = float(schedule_type[len('warm'):])
        except ValueError:
            # e.g. 'warmup' — keep the original error contract.
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * warm_frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    else:
        raise NotImplementedError(schedule_type)
    return betas
def get_constrain_function(ground_truth, mask, eps, num_steps=1):
    """Build a guidance function that pulls masked coordinates toward
    ``ground_truth`` during sampling.

    The step size decays quadratically from ``eps`` at t = 0 up to 0 at
    t = 999; timesteps >= 1000 apply no correction.

    :param ground_truth: target values for the constrained (observed) entries
    :param mask: same shape as the sample; nonzero where the constraint applies
    :param eps: maximum gradient-step size (used at t = 0)
    :param num_steps: number of correction steps applied per call
    :return: ``constrain_fn(x, t)`` producing the corrected sample
    """
    # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
    schedule = np.linspace(0, np.sqrt(eps), 1000) ** 2
    step_sizes = schedule[::-1]

    def constrain_fn(x, t):
        step = step_sizes[t] if t < 1000 else 0
        for _ in range(num_steps):
            x = x - step * ((x - ground_truth) * mask)
        return x

    return constrain_fn
#############################################################################
def get_dataset(dataroot, npoints,category,use_mask=False):
    """Train/val ShapeNet point-cloud datasets; the val split reuses the train
    split's global mean/std so both live in the same normalized frame."""
    tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True, use_mask = use_mask)
    te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
        categories=category, split='val',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        # Normalize val with the train statistics.
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        use_mask=use_mask
    )
    return tr_dataset, te_dataset
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
    """Single-view reconstruction val dataset.

    The train point-cloud dataset is built only to obtain the global
    normalization statistics, which the single-view dataset reuses.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
        cache=os.path.join(mesh_root, '../cache'), split='val',
        categories=category,
        # Fixed camera for rendering the single view.
        radius=3, elev=0, azim=0, img_size=512, focal_length=1000,
        npoints=npoints, sv_samples=200,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        )
    return te_dataset
def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category):
    """Multi-view reconstruction val dataset (mesh-rendering variant).

    The train point-cloud dataset is built only to obtain the global
    normalization statistics, which the multi-view dataset reuses.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
        cache=os.path.join(mesh_root, '../cache'), split='val',
        categories=category,
        npoints=npoints, sv_samples=200,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        )
    return te_dataset
def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category):
    """Multi-view reconstruction val dataset (precomputed-views variant).

    Differs from get_mvr_dataset by taking a root of prerendered views
    (root_views) instead of a mesh root, with the cache stored under the
    point-cloud root. Train split is built only for normalization stats.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
        cache=os.path.join(pc_dataroot, '../cache'), split='val',
        categories=category,
        npoints=npoints, sv_samples=200,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        )
    return te_dataset
def evaluate_gen(opt, ref_pcs, logger):
    """Score saved generated samples (opt.eval_path) against reference clouds.

    If ref_pcs is None, the val split is loaded and de-normalized to build the
    references. Prints/logs the full metric suite plus JSD.
    """
    if ref_pcs is None:
        _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
        ref = []
        for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'):
            x = data['test_points']
            m, s = data['mean'].float(), data['std'].float()
            # De-normalize back to the original coordinate frame.
            ref.append(x*s + m)

        ref_pcs = torch.cat(ref, dim=0).contiguous()

    logger.info("Loading sample path: %s"
                % (opt.eval_path))
    sample_pcs = torch.load(opt.eval_path).contiguous()

    logger.info("Generation sample size:%s reference size: %s"
                % (sample_pcs.size(), ref_pcs.size()))

    # Compute metrics
    results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    # Detach tensor-valued metrics to plain Python floats for printing.
    results = {k: (v.cpu().detach().item()
                   if not isinstance(v, float) else v) for k, v in results.items()}

    pprint(results)
    logger.info(results)

    jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
    pprint('JSD: {}'.format(jsd))
    logger.info('JSD: {}'.format(jsd))
def evaluate_recon(opt, netE, save_dir, logger):
    """Constraint-guided completion over the multi-view dataset, one view per
    batch (cycling through 24 views), then EMD/CD evaluation of the results.

    Saves ground truth, depth images, masked inputs, and per-view metrics
    under save_dir.
    """
    test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    ref = []
    samples = []
    images = []
    masked = []
    k = 0
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        # Pick a single view per batch, cycling over the 24 available views.
        randind = i%24
        gt_all = data['test_points'][:,randind:randind+1]
        x_all = data['sv_points'][:,randind:randind+1]
        mask_all= data['masks'][:,randind:randind+1]
        img_all = data['image'][:,randind:randind+1]

        # Flatten (batch, view) and move channels first for the network.
        B,V,N,C = x_all.shape
        x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous()
        mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous()
        img = img_all.reshape(B*V, *img_all.shape[2:])

        m, s = data['mean'].float(), data['std'].float()
        # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
        # for t in [10]:
        # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
        #                                                                    opt.constrain_steps)).detach().cpu()

        # Sample from noise while pulling observed points toward the input.
        recon = netE.gen_samples(x.shape, 'cuda',
                                 constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
                                                                     opt.constrain_steps),
                                 clip_denoised=False).detach().cpu()

        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()

        # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
        #     write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #     visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
        #
        #     export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #
        #     k+=1

        # De-normalize everything back to the original coordinate frame.
        x_adj = x.reshape(B,V,N,C)* s + m
        recon_adj = recon.reshape(B,V,N,C)* s + m
        img = img.reshape(B,V,*img.shape[1:])

        ref.append( gt_all * s + m)
        masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
        samples.append(recon_adj)
        images.append(img)

    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)

    B, V, N, C = ref_pcs.shape

    torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B*V, N, C),
                     ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)

    # Reshape flat per-sample metrics back to (batch, view).
    results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}

    pprint({key: val.mean() for key, val in results.items()})
    logger.info({key: val.mean() for key, val in results.items()})

    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))

    #
    # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    #
    # results = {k: (v.cpu().detach().item()
    #                if not isinstance(v, float) else v) for k, v in results.items()}
    # pprint(results)
    # logger.info(results)
def evaluate_recon_mvr(opt, netE, save_dir, logger):
    """Constraint-guided completion over ALL views, sampling 5 candidate
    completions per input and keeping the one with the lowest masked chamfer
    distance to the partial input, then EMD/CD evaluation.

    Saves ground truth, depth images, masked inputs, and per-view metrics
    under save_dir.
    """
    test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    ref = []
    samples = []
    images = []
    masked = []
    k = 0
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        mask_all= data['masks']
        img_all = data['image']

        B,V,N,C = x_all.shape
        # Replicate the single ground truth across all V views.
        gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)

        # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
        # for t in [10]:
        # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
        #                                                                    opt.constrain_steps)).detach().cpu()

        # Best-of-5 sampling: keep the candidate closest to the partial input
        # on the observed (masked) points.
        cd_res = []
        recon_res = []
        for p in range(5):
            x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
            mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
            img = img_all.reshape(B * V, *img_all.shape[2:])

            m, s = data['mean'].float(), data['std'].float()

            recon = netE.gen_samples(x.shape, 'cuda',
                                     constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
                                                                         opt.constrain_steps),
                                     clip_denoised=False).detach().cpu()

            recon = recon.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            # Masked squared distance to the observed points.
            cd = (((recon - x)**2)*mask.transpose(1, 2).contiguous()).sum(dim=(1,2))

            cd_res.append(cd)
            recon_res.append(recon)

        cd_res = torch.stack(cd_res, dim=0)
        recon_res = torch.stack(recon_res, dim=0)
        _, argmin = torch.min(cd_res, 0)
        # Pick the best candidate per sample.
        recon = recon_res[argmin,torch.arange(0,argmin.shape[0])]

        # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
        #     write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #     visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
        #
        #     export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #
        #     k+=1

        # De-normalize everything back to the original coordinate frame.
        x_adj = x.reshape(B,V,N,C)* s + m
        recon_adj = recon.reshape(B,V,N,C)* s + m
        img = img.reshape(B,V,*img.shape[1:])

        ref.append( gt_all * s + m)
        masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
        samples.append(recon_adj)
        images.append(img)

    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)

    B, V, N, C = ref_pcs.shape

    torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B*V, N, C),
                     ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)

    # Reshape flat per-sample metrics back to (batch, view).
    results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}

    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})

    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))

    #
    # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    #
    # results = {k: (v.cpu().detach().item()
    #                if not isinstance(v, float) else v) for k, v in results.items()}
    # pprint(results)
    # logger.info(results)
def generate(netE, opt, logger):
    """Sample one point cloud per val example, render previews/XMLs per batch,
    save all samples to opt.eval_path, and return the de-normalized
    reference clouds for later metric computation."""
    _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    with torch.no_grad():
        samples = []
        ref = []

        for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'):
            x = data['test_points'].transpose(1,2)
            m, s = data['mean'].float(), data['std'].float()

            # Unconditional sampling shaped like the reference batch.
            gen = netE.gen_samples(x.shape,
                                   'cuda', clip_denoised=False).detach().cpu()

            gen = gen.transpose(1,2).contiguous()
            x = x.transpose(1,2).contiguous()

            # De-normalize back to the original coordinate frame.
            gen = gen * s + m
            x = x * s + m
            samples.append(gen)
            ref.append(x)

            visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None,
                                       None, None)

            write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), default_color='b')

        samples = torch.cat(samples, dim=0)
        ref = torch.cat(ref, dim=0)

        torch.save(samples, opt.eval_path)

    return ref
def generate_multimodal(opt, netE, save_dir, logger):
    """Sample 5 diverse completions ('modes') per single-view input and export
    depth images, XML renders, and PLY point clouds for each mode."""
    test_dataset = get_svr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    k = 0
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        mask_all= data['masks']
        img_all = data['image']

        # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
        # for t in [10]:
        # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
        #                                                                    opt.constrain_steps)).detach().cpu()
        Path(os.path.join(save_dir, 'x_%03d'%k)).mkdir(exist_ok=True, parents=True)
        Path(os.path.join(save_dir, 'x_ply_%03d' % k)).mkdir(exist_ok=True, parents=True)

        # Draw 5 independent samples (modes) for the same partial input.
        for v in range(5):
            x = x_all.transpose(1, 2).contiguous()
            mask = mask_all.transpose(1, 2).contiguous()
            img = img_all

            m, s = data['mean'].float(), data['std'].float()

            recon = netE.gen_samples(x.shape, 'cuda',
                                     constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
                                                                         opt.constrain_steps),
                                     clip_denoised=False).detach().cpu()

            recon = recon.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            # Export each sample in the batch: depth image, XML render, PLY.
            for d in zip(list(gt_all), list(recon), list(x), list(img)):
                im = np.fliplr(np.flipud(d[-1]))
                plt.imsave(os.path.join(save_dir, 'depth_%03d.png'%k), im, cmap='gray')
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k, 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())

                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k, 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())

                k+=1
def main(opt):
    """Set up output dirs/logging, restore the newest checkpoint from
    opt.ckpt_dir, then run generation and the requested evaluations.

    Only the newest checkpoint is processed (the loop exits after the first
    iteration, matching the original behavior).
    """
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    outf_syn, = setup_output_subdirs(output_dir, 'syn')

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        # Wrap the inner denoiser for multi-GPU execution.
        return nn.parallel.DataParallel(m)

    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)

    netE.eval()

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    with torch.no_grad():
        # Newest epoch first (filenames like ".../epoch_2799.pth").
        for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
            # BUG FIX: a hard-coded checkpoint path (debugging leftover) used
            # to be assigned here, so the loop always loaded the same file
            # regardless of opt.ckpt_dir. Use the loop's checkpoint instead.
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)

            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])

            ref = None
            if opt.generate:
                # Sample the full val split and save the samples for scoring.
                epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1])
                opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch))
                Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
                ref = generate(netE, opt, logger)
            if opt.eval_gen:
                # Evaluate generation
                evaluate_gen(opt, ref, logger)
            if opt.eval_recon:
                # Evaluate single-view completion
                evaluate_recon(opt, netE, outf_syn, logger)
            if opt.eval_recon_mvr:
                # Evaluate multi-view completion
                evaluate_recon_mvr(opt, netE, outf_syn, logger)
            if opt.generate_multimodal:
                generate_multimodal(opt, netE, outf_syn, logger)

            # Stop after the newest checkpoint.
            exit()
def parse_args():
    """Parse command-line options for car-generation evaluation.

    Returns an argparse Namespace with an extra ``cuda`` attribute set from
    ``torch.cuda.is_available()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
                        help='path to the ShapeNet point-cloud root')  # help fixed: was a copy-paste of the batch-size text
    parser.add_argument('--classes', default=['car'])

    parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    parser.add_argument('--generate', default=True)
    parser.add_argument('--eval_gen', default=True)
    # flags read by main(); previously missing from the parser, which made
    # `opt.eval_recon` raise AttributeError.
    parser.add_argument('--eval_recon', default=False)
    parser.add_argument('--eval_recon_mvr', default=False)
    parser.add_argument('--generate_multimodal', default=False)

    # type= added to the numeric options below so command-line overrides
    # arrive as numbers instead of strings.
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)

    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    # params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=0.2)
    parser.add_argument('--constrain_steps', type=int, default=1)

    # default='' dropped: argparse ignores the default for required options
    parser.add_argument('--ckpt_dir', required=True, help="path to netE (to continue training)")

    '''eval'''
    parser.add_argument('--eval_path',
                        default='')

    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')

    opt = parser.parse_args()

    # record CUDA availability on the options namespace
    opt.cuda = torch.cuda.is_available()

    return opt
if __name__ == '__main__':
    # CLI entry point: parse options, seed all RNGs, then run evaluation.
    opt = parse_args()
    set_seed(opt)

    main(opt)
# results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair

View file

@ -1,911 +0,0 @@
import torch
from pprint import pprint
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_generation import PVCNN2Base
from tqdm import tqdm
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence between two diagonal Gaussians given as (mean, log-variance).
    All arguments broadcast element-wise; the result has the broadcast shape.
    """
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (logvar2 - logvar1 - 1.0 + var_ratio + scaled_sq_diff)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-likelihood of ``x`` under a Gaussian discretized into unit bins.

    Assumes data lies in [0, 1]; the leftmost bin (x < 0.001) uses the lower
    tail of the CDF and the rightmost bin (x > 0.999) uses the upper tail.
    """
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered = x - means
    inv_stdv = torch.exp(-log_scales)
    # CDF at the upper and lower edges of the half-width-0.5 bin around x
    cdf_plus = std_normal.cdf(inv_stdv * (centered + 0.5))
    cdf_min = std_normal.cdf(inv_stdv * (centered - .5))

    tiny = 1e-12  # floor before log to avoid -inf
    log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * tiny))
    log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min) * tiny))
    cdf_delta = cdf_plus - cdf_min
    log_cdf_delta = torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * tiny))

    log_probs = torch.where(
        x < 0.001, log_cdf_plus,
        torch.where(x > 0.999, log_one_minus_cdf_min, log_cdf_delta))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """DDPM forward/reverse-process helper.

    Precomputes every beta/alpha-derived coefficient needed for the forward
    posterior q(x_t | x_0), the reverse posterior q(x_{t-1} | x_t, x_0), and
    ancestral sampling.  Schedule math is done in float64 numpy for accuracy
    and stored as float32 torch tensors (moved to the data's device on use).
    """
    def __init__(self, betas, loss_type, model_mean_type, model_var_type):
        # loss_type: e.g. 'mse'; model_mean_type: 'eps' (network predicts noise);
        # model_var_type: 'fixedsmall' | 'fixedlarge'
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        # alpha-bar shifted right by one step, with alpha-bar_0 := 1
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """Mean / variance / log-variance of q(x_t | x_0) per batch element."""
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Predict the reverse-process mean/variance p(x_{t-1} | x_t).

        Runs the denoiser, converts its eps prediction to x_0, and forms the
        posterior mean; the variance is fixed by ``model_var_type``.
        """
        model_output = denoise_fn(data, t)

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)

            if clip_denoised:
                x_recon = torch.clamp(x_recon, -.5, .5)

            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
        else:
            # NOTE(review): this reports loss_type but branches on
            # model_mean_type — probably meant NotImplementedError(self.model_mean_type).
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape == data.shape
        assert model_variance.shape == model_log_variance.shape == data.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Invert q_sample: recover x_0 from x_t and the predicted noise eps."""
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
        """
        Sample from the model: one ancestral step x_t -> x_{t-1}.
        Noise is suppressed at t == 0; set use_var=False for the mean only.
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                             return_pred_xstart=True)
        noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
        assert noise.shape == data.shape
        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))

        sample = model_mean
        if use_var:
            sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        assert sample.shape == pred_xstart.shape
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, denoise_fn, shape, device,
                      noise_fn=torch.randn, constrain_fn=lambda x, t: x,
                      clip_denoised=True, max_timestep=None, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

        constrain_fn(x, t) is applied before every reverse step (used for
        shape-completion guidance); identity by default.
        """
        if max_timestep is None:
            final_time = self.num_timesteps
        else:
            final_time = max_timestep

        assert isinstance(shape, (tuple, list))
        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
            img_t = constrain_fn(img_t, t)
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False).detach()

        assert img_t.shape == shape
        return img_t

    def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t: x):
        """Diffuse x0 forward for t steps, then denoise back; returns the result."""
        assert t >= 1

        t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t - 1)
        encoding = self.q_sample(x0, t_vec)

        img_t = encoding

        for k in reversed(range(0, t)):
            img_t = constrain_fn(img_t, k)
            t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=False, return_pred_xstart=False, use_var=True).detach()

        return img_t

    def interpolate(self, x0, x1, t, lamb, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t: x):
        """Diffuse x0 and x1 to step t, blend their encodings with weight lamb
        (lamb on x0), and denoise the mixture back to a sample."""
        assert t >= 1

        t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t - 1)
        encoding0 = self.q_sample(x0, t_vec)
        encoding1 = self.q_sample(x1, t_vec)

        enc = encoding0 * lamb + (1 - lamb) * encoding1

        img_t = enc

        for k in reversed(range(0, t)):
            img_t = constrain_fn(img_t, k)
            t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=False, return_pred_xstart=False, use_var=True).detach()

        return img_t
class PVCNN2(PVCNN2Base):
    """PVCNN2 denoiser configuration: only fixes the set-abstraction and
    feature-propagation block layouts; all behavior lives in PVCNN2Base."""
    # NOTE(review): tuple semantics below are defined by PVCNN2Base (not
    # visible in this file) — presumably ((conv_channels, num_blocks,
    # voxel_resolution), (num_centers, radius, num_neighbors, mlp_channels));
    # confirm against model/pvcnn_generation.py.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, embed_dim, use_att, dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        # Pure pass-through to the base class; kept so callers have a concrete
        # class carrying the block layouts above.
        super().__init__(
            num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Couples the PVCNN2 denoiser with a GaussianDiffusion schedule and
    exposes training-loss and sampling entry points."""
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)

        # extra_feature_channels=0: inputs are raw xyz point clouds only
        self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # NOTE(review): _prior_bpd is not defined on the GaussianDiffusion in
        # this file — confirm it exists in the deployed version.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        # NOTE(review): calc_bpd_loop is likewise not defined in this file.
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt': mse_bt
        }

    def _denoise(self, data, t):
        # data: (B, D, N) float point clouds; t: (B,) int64 timesteps
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)

        assert out.shape == torch.Size([B, D, N])
        return out

    def get_loss_iter(self, data, noises=None):
        """One training step's per-sample losses at uniformly drawn timesteps."""
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # resample noise for every element whose timestep is nonzero
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        # NOTE(review): p_losses is not defined on the GaussianDiffusion shown
        # in this file — confirm.
        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t: x,
                    clip_denoised=False, max_timestep=None,
                    keep_running=False):
        """Ancestral sampling; constrain_fn enables guided completion."""
        return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            constrain_fn=constrain_fn,
                                            clip_denoised=clip_denoised, max_timestep=max_timestep,
                                            keep_running=keep_running)

    def reconstruct(self, x0, t, constrain_fn=lambda x, t: x):
        return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)

    def interpolate(self, x0, x1, t, lamb, constrain_fn=lambda x, t: x):
        return self.diffusion.interpolate(x0, x1, t, lamb, self._denoise, constrain_fn=constrain_fn)

    def train(self):
        # delegates to the inner module (diffusion schedule has no parameters)
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # f wraps the denoiser (e.g. in nn.DataParallel)
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build a variance (beta) schedule of length ``time_num``.

    Args:
        schedule_type: 'linear', or 'warmF' where F is the fraction of steps
            (e.g. 'warm0.1', 'warm0.2', 'warm0.5') that linearly warm up from
            ``b_start`` to ``b_end``; the remaining steps stay at ``b_end``.
        b_start: initial beta value.
        b_end: final beta value.
        time_num: number of diffusion timesteps.

    Returns:
        np.ndarray of shape (time_num,), dtype float64.

    Raises:
        NotImplementedError: for an unrecognized schedule_type.
    """
    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num, dtype=np.float64)
    if schedule_type.startswith('warm'):
        # The three previous 'warm0.1'/'warm0.2'/'warm0.5' branches were
        # identical up to the fraction; parse it instead of duplicating code
        # (this also generalizes to any fraction in (0, 1]).
        try:
            frac = float(schedule_type[len('warm'):])
        except ValueError:
            raise NotImplementedError(schedule_type)
        if not 0.0 < frac <= 1.0:
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
        return betas
    raise NotImplementedError(schedule_type)
def get_constrain_function(ground_truth, mask, eps, num_steps=1):
    '''
    Build a guidance function that pulls the masked entries of x toward
    ground_truth during reverse sampling.

    :param ground_truth: target shape constraint (same shape as x)
    :param mask: binary mask selecting the constrained entries
    :param eps: maximum step size of the constraint
    :param num_steps: gradient-style pull iterations per call
    :return: function (x, t) -> constrained x
    '''
    # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
    # Quadratic decay over 1000 steps: strongest pull near t == 0 (the end of
    # sampling), fading to zero at t == 999; t >= 1000 applies no constraint.
    schedule = np.linspace(0, np.sqrt(eps), 1000) ** 2
    step_sizes = list(reversed(schedule))

    def constrain_fn(x, t):
        step = step_sizes[t] if (t < 1000) else 0
        for _ in range(num_steps):
            x = x - step * ((x - ground_truth) * mask)
        return x

    return constrain_fn
#############################################################################
def get_dataset(dataroot, npoints, category, use_mask=False):
    """Return (train, val) ShapeNet15k point-cloud datasets.

    The val split reuses the normalization statistics computed on the train
    split so both share the same coordinate frame.
    """
    train_set = ShapeNet15kPointClouds(
        root_dir=dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True, use_mask=use_mask)
    val_set = ShapeNet15kPointClouds(
        root_dir=dataroot,
        categories=category, split='val',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        all_points_mean=train_set.all_points_mean,
        all_points_std=train_set.all_points_std,
        use_mask=use_mask
    )
    return train_set, val_set
def get_svr_dataset(pc_dataroot, mesh_root, npoints, category):
    """Build the single-view-reconstruction validation set.

    The train split is instantiated only to provide the normalization
    statistics that the single-view dataset must share.
    """
    norm_source = ShapeNet15kPointClouds(
        root_dir=pc_dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    return ShapeNet_Singleview_Points(
        root_pc=pc_dataroot, root_mesh=mesh_root,
        cache=os.path.join(mesh_root, '../cache'), split='val',
        categories=category,
        radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
        npoints=npoints, sv_samples=200,
        all_points_mean=norm_source.all_points_mean,
        all_points_std=norm_source.all_points_std,
    )
def get_mvr_dataset(pc_dataroot, mesh_root, npoints, category):
    """Build the multi-view-reconstruction validation set (mesh-based views).

    Train-split statistics are borrowed for normalization, matching the other
    dataset factories in this file.
    """
    norm_source = ShapeNet15kPointClouds(
        root_dir=pc_dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    return ShapeNet_Multiview_Points(
        root_pc=pc_dataroot, root_mesh=mesh_root,
        cache=os.path.join(mesh_root, '../cache'), split='val',
        categories=category,
        npoints=npoints, sv_samples=200,
        all_points_mean=norm_source.all_points_mean,
        all_points_std=norm_source.all_points_std,
    )
def get_mvr_dataset_v2(pc_dataroot, views_root, npoints, category):
    """Build the multi-view-reconstruction validation set from pre-rendered
    views (``views_root``) instead of meshes; cache lives next to the point
    clouds here rather than next to the meshes."""
    norm_source = ShapeNet15kPointClouds(
        root_dir=pc_dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    return ShapeNet_Multiview_Points(
        root_pc=pc_dataroot, root_views=views_root,
        cache=os.path.join(pc_dataroot, '../cache'), split='val',
        categories=category,
        npoints=npoints, sv_samples=200,
        all_points_mean=norm_source.all_points_mean,
        all_points_std=norm_source.all_points_std,
    )
def evaluate_gen(opt, ref_pcs, logger):
    """Evaluate generated samples against the validation reference clouds.

    If ``ref_pcs`` is None, the (denormalized) validation point clouds are
    rebuilt from the dataset; generated samples are loaded from
    ``opt.eval_path``.  Currently only JSD is reported — the full metric
    suite is commented out below.
    """
    if ref_pcs is None:
        _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
        ref = []
        for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'):
            x = data['test_points']
            m, s = data['mean'].float(), data['std'].float()
            # undo the dataset normalization so metrics are in world scale
            ref.append(x * s + m)

        ref_pcs = torch.cat(ref, dim=0).contiguous()

    logger.info("Loading sample path: %s"
                % (opt.eval_path))
    sample_pcs = torch.load(opt.eval_path).contiguous()

    logger.info("Generation sample size:%s reference size: %s"
                % (sample_pcs.size(), ref_pcs.size()))

    # Compute metrics
    # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    # results = {k: (v.cpu().detach().item()
    #                if not isinstance(v, float) else v) for k, v in results.items()}
    #
    # pprint(results)
    # logger.info(results)

    jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
    pprint('JSD: {}'.format(jsd))
    logger.info('JSD: {}'.format(jsd))
def evaluate_recon(opt, netE, save_dir, logger):
    """Reconstruct one view per batch from the multi-view dataset and report
    EMD/CD against ground truth; saves tensors and results under save_dir.

    NOTE(review): the mesh root is hard-coded to a cluster path — parameterize
    before reuse.
    """
    test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    ref = []
    samples = []
    images = []
    masked = []
    k = 0
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        # pick one of 24 views per shape, cycling with the batch index
        randind = i % 24
        gt_all = data['test_points'][:, randind:randind + 1]
        x_all = data['sv_points'][:, randind:randind + 1]
        mask_all = data['masks'][:, randind:randind + 1]
        img_all = data['image'][:, randind:randind + 1]
        B, V, N, C = x_all.shape
        # flatten (batch, view) and move channels first for the network
        gt = gt_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        img = img_all.reshape(B * V, *img_all.shape[2:])
        m, s = data['mean'].float(), data['std'].float()

        # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
        # for t in [10]:
        #     t_vec = torch.empty(gt.shape[0], dtype=torch.int64, device='cuda').fill_(80)
        #     recon = netE.diffusion.q_sample(gt.cuda(), t_vec).detach().cpu()
        # sample with partial-shape guidance from the masked single view
        recon = netE.gen_samples(x.shape, 'cuda',
                                 constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
                                                                     opt.constrain_steps),
                                 clip_denoised=False).detach().cpu()
        # NOTE(review): the transpose-back lines below are commented out, yet
        # the reshape(B, V, N, C) further down assumes (N, C) ordering — as
        # written x/recon are still (B*V, C, N), so the reshape scrambles
        # axes.  Confirm whether the transposes should be restored.
        # recon = recon.transpose(1, 2).contiguous()
        # x = x.transpose(1, 2).contiguous()
        # gt = gt.transpose(1, 2).contiguous()

        # write_to_xml_batch(os.path.join(save_dir, 'intermediate_%03d' % i),
        #                    (recon.detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy())
        # write_to_xml_batch(os.path.join(save_dir, 'x_%03d' % i),
        #                    (gt.detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy())
        # write_to_xml_batch(os.path.join(save_dir, 'noise_%03d' % i),
        #                    (torch.randn_like(gt).detach().cpu() * s[0].squeeze() + m[0].squeeze()).numpy())
        # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
        #     write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #     visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
        #
        #     k+=1

        # denormalize back to world coordinates
        x_adj = x.reshape(B, V, N, C) * s + m
        recon_adj = recon.reshape(B, V, N, C) * s + m
        img = img.reshape(B, V, *img.shape[1:])

        ref.append(gt_all * s + m)
        # keep only the observed single-view points for the "masked" dump
        masked.append(x_adj[:, :, :test_dataloader.dataset.sv_samples, :])
        samples.append(recon_adj)
        images.append(img)

    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)

    B, V, N, C = ref_pcs.shape
    torch.save(ref_pcs.reshape(B * V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    torch.save(images.reshape(B * V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B * V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B * V, N, C),
                     ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)
    # regroup flat per-sample metrics back into (batch, view)
    results = {ky: val.reshape(B, V) if val.shape == torch.Size([B * V, ]) else val for ky, val in results.items()}

    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})

    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))

    #
    # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    #
    # results = {k: (v.cpu().detach().item()
    #                if not isinstance(v, float) else v) for k, v in results.items()}
    # pprint(results)
    # logger.info(results)
def evaluate_recon_mvr(opt, netE, save_dir, logger):
    """Reconstruct every view of every shape (multi-view) and report EMD/CD
    against the shared ground-truth cloud; saves tensors and results under
    ``save_dir``.

    NOTE(review): the views root is hard-coded to a cluster path —
    parameterize before reuse.
    """
    test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                                      opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)
    ref = []
    samples = []
    images = []
    masked = []
    k = 0
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        mask_all = data['masks']
        # img_all = data['image']
        B, V, N, C = x_all.shape
        # one ground-truth cloud per shape, broadcast across its V views
        gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1)

        # flatten (batch, view) and move channels first for the network
        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        # img = img_all.reshape(B * V, *img_all.shape[2:])

        m, s = data['mean'].float(), data['std'].float()

        # sample with partial-shape guidance from the masked views
        recon = netE.gen_samples(x.shape, 'cuda',
                                 constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
                                                                     opt.constrain_steps),
                                 clip_denoised=False).detach().cpu()
        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()

        # denormalize back to world coordinates
        x_adj = x.reshape(B, V, N, C) * s + m
        recon_adj = recon.reshape(B, V, N, C) * s + m
        # img = img.reshape(B,V,*img.shape[1:])

        ref.append(gt_all * s + m)
        # keep only the observed single-view points for the "masked" dump
        masked.append(x_adj[:, :, :test_dataloader.dataset.sv_samples, :])
        samples.append(recon_adj)
        # images.append(img)

    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    # images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)

    B, V, N, C = ref_pcs.shape
    torch.save(ref_pcs.reshape(B, V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    # torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B * V, N, C),
                     ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)
    # regroup flat per-sample metrics back into (batch, view)
    results = {ky: val.reshape(B, V) if val.shape == torch.Size([B * V, ]) else val for ky, val in results.items()}

    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})

    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))

    #
    # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    #
    # results = {k: (v.cpu().detach().item()
    #                if not isinstance(v, float) else v) for k, v in results.items()}
    # pprint(results)
    # logger.info(results)

    # free the large buffers before returning
    del ref_pcs, masked
def generate(netE, opt, logger):
    """Sample unconditional point clouds matching the validation set's shapes.

    Saves the concatenated samples to ``opt.eval_path`` (plus per-batch png
    and xml renders next to it) and returns the denormalized reference
    clouds for later evaluation.
    """
    _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    with torch.no_grad():
        samples = []
        ref = []

        for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'):
            x = data['test_points'].transpose(1, 2)
            m, s = data['mean'].float(), data['std'].float()

            # only x.shape is used: generation starts from pure noise
            gen = netE.gen_samples(x.shape,
                                   'cuda', clip_denoised=False).detach().cpu()

            gen = gen.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            # denormalize both to world coordinates
            gen = gen * s + m
            x = x * s + m
            samples.append(gen)
            ref.append(x)

            # overwritten every batch; keeps only the last batch's preview
            visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x.png'), gen[:64], None,
                                       None, None)

            # NOTE(review): render category is hard-coded to 'chair' even
            # though opt.classes may differ — confirm before rendering other
            # categories.
            write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d' % i), gen[:min(gen.shape[0], 40)].numpy(), cat='chair')

        samples = torch.cat(samples, dim=0)
        ref = torch.cat(ref, dim=0)

        torch.save(samples, opt.eval_path)

    return ref
def generate_multimodal(opt, netE, save_dir, logger):
    """For each single-view input, draw 10 completion samples ("modes") and
    dump the depth image plus xml/ply point clouds for rendering.

    NOTE(review): the mesh root is hard-coded to a cluster path —
    parameterize before reuse.
    """
    test_dataset = get_svr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        mask_all = data['masks']
        img_all = data['image']

        # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
        # for t in [10]:
        # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
        #                                                             opt.constrain_steps)).detach().cpu()

        # one independent guided sampling pass per mode
        for v in range(10):
            x = x_all.transpose(1, 2).contiguous()
            mask = mask_all.transpose(1, 2).contiguous()
            img = img_all
            m, s = data['mean'].float(), data['std'].float()
            recon = netE.gen_samples(x.shape, 'cuda',
                                     constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
                                                                         opt.constrain_steps),
                                     clip_denoised=False).detach().cpu()

            recon = recon.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            # d = (gt, recon, sv_points, depth image) per sample
            for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
                # flip to match the renderer's image orientation
                im = np.fliplr(np.flipud(d[-1]))
                plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png' % (i, p)), im, cmap='gray')
                # stack (gt, recon, partial) and denormalize before export
                write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % v), (torch.stack(d[:-1], dim=0) * s[0:1] + m[0:1]).numpy())

                export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % v), (torch.stack(d[:-1], dim=0) * s[0:1] + m[0:1]).numpy())
def main(opt):
    """Entry point: build the diffusion model, then run the evaluations
    selected by the opt flags for checkpoints in ``opt.ckpt_dir`` (highest
    epoch first).

    Side effects: creates an experiment output directory, writes logs and
    sample tensors under its ``syn`` subdirectory, then exits the process.
    """
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)
    # single output subdirectory for synthesized samples
    outf_syn, = setup_output_subdirs(output_dir, 'syn')

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    def _transform_(m):
        return nn.parallel.DataParallel(m)

    if opt.cuda:
        # bug fix: `.cuda()` was previously also called unconditionally,
        # which crashed on CUDA-less machines; move to GPU only when enabled.
        netE = netE.cuda()
        netE.multi_gpu_wrapper(_transform_)
    netE.eval()

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
    with torch.no_grad():
        # iterate checkpoints from the highest epoch number down
        for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]))):
            # bug fix: a hard-coded debug path used to override the loop
            # variable here, so every iteration reloaded the same checkpoint.
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)
            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])

            ref = None
            if opt.generate:
                epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1])
                opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch))
                Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
                ref = generate(netE, opt, logger)
            if opt.eval_gen:
                # Evaluate generation
                evaluate_gen(opt, ref, logger)
            if opt.eval_recon:
                # Evaluate single-view reconstruction
                evaluate_recon(opt, netE, outf_syn, logger)
            if opt.eval_recon_mvr:
                # Evaluate multi-view reconstruction
                evaluate_recon_mvr(opt, netE, outf_syn, logger)
            if opt.generate_multimodal:
                generate_multimodal(opt, netE, outf_syn, logger)
            # original behavior: stop after the newest checkpoint
            exit()
def parse_args():
    """Parse command-line options for chair-generation evaluation.

    Returns an argparse Namespace with an extra ``cuda`` attribute set from
    ``torch.cuda.is_available()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
                        help='path to the ShapeNet point-cloud root')  # help fixed: was a copy-paste of the batch-size text
    parser.add_argument('--classes', default=['chair'])

    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    parser.add_argument('--generate', default=True)
    parser.add_argument('--eval_gen', default=False)
    parser.add_argument('--eval_recon', default=False)
    parser.add_argument('--eval_recon_mvr', default=False)
    parser.add_argument('--generate_multimodal', default=False)

    # type= added to the numeric options below so command-line overrides
    # arrive as numbers instead of strings.
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--npoints', type=int, default=2048)

    '''model'''
    parser.add_argument('--beta_start', type=float, default=0.0001)
    parser.add_argument('--beta_end', type=float, default=0.02)
    parser.add_argument('--schedule_type', default='linear')
    parser.add_argument('--time_num', type=int, default=1000)

    # params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    # constrain function
    parser.add_argument('--constrain_eps', type=float, default=.051)
    parser.add_argument('--constrain_steps', type=int, default=1)

    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/66_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do_best/2020-10-16-12-23-44/', help="path to netE (to continue training)")

    '''eval'''
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair/2020-10-18-13-46-21/syn/epoch_1699_samples.pth')

    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')

    opt = parser.parse_args()

    # record CUDA availability on the options namespace
    opt.cuda = torch.cuda.is_available()

    return opt
if __name__ == '__main__':
    opt = parse_args()
    set_seed(opt)  # presumably seeds RNGs from opt.manualSeed (helper from utils.file_utils import *) — TODO confirm
    main(opt)
# results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair/2020-10-18-13-46-21

View file

@ -1,925 +0,0 @@
import torch
import functools
from pprint import pprint
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD, distChamfer
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_generation import PVCNN2Base
from tqdm import tqdm
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL(N(mean1, e^logvar1) || N(mean2, e^logvar2)).

    All arguments are tensors (or broadcastable values) of matching shape;
    the result has the broadcast shape.
    """
    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_mean_gap = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    log_var_gap = logvar2 - logvar1
    return 0.5 * (variance_ratio + scaled_mean_gap + log_var_gap - 1.0)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-likelihood of *x* under a Gaussian discretized to bins of width 1.

    Assumes data is integers in [0, 1]. The probability mass of each bin is
    the CDF difference over [x-0.5, x+0.5]; the extreme bins use the open
    tails instead so the distribution sums to 1.
    """
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered = x - means
    inv_stdv = torch.exp(-log_scales)
    cdf_upper = std_normal.cdf(inv_stdv * (centered + 0.5))
    cdf_lower = std_normal.cdf(inv_stdv * (centered - .5))

    # Clamp before log to avoid -inf when the CDF saturates.
    def _clamped_log(v):
        return torch.log(torch.max(v, torch.ones_like(v) * 1e-12))

    log_cdf_upper = _clamped_log(cdf_upper)
    log_tail_upper = _clamped_log(1. - cdf_lower)
    log_bin_mass = _clamped_log(cdf_upper - cdf_lower)

    # Lowest bin -> left tail, highest bin -> right tail, otherwise bin mass.
    log_probs = torch.where(
        x < 0.001, log_cdf_upper,
        torch.where(x > 0.999, log_tail_upper, log_bin_mass))
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """DDPM-style Gaussian diffusion over point clouds.

    Holds the beta schedule and all per-timestep coefficients derived from it,
    and implements the forward process q(x_t | x_0), the posterior
    q(x_{t-1} | x_t, x_0), and ancestral sampling with a learned denoiser.
    All coefficient buffers are CPU tensors moved with ``.to(device)`` on use.
    """
    def __init__(self,betas, loss_type, model_mean_type, model_var_type):
        # loss_type: e.g. 'mse'; model_mean_type: 'eps' (network predicts noise);
        # model_var_type: 'fixedsmall' | 'fixedlarge' (variance is not learned).
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1. - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        # Moments of q(x_t | x_0): mean sqrt(ᾱ_t)·x_0, variance (1-ᾱ_t).
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Moments of p(x_{t-1} | x_t) given the denoiser output at step t."""
        model_output = denoise_fn(data, t)

        if self.model_var_type in ['fixedsmall', 'fixedlarge']:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                'fixedlarge': (self.betas.to(data.device),
                               torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
                'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == 'eps':
            # Network predicts the noise eps; recover x_0, then the posterior mean.
            x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)

            if clip_denoised:
                x_recon = torch.clamp(x_recon, -.5, .5)

            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
        else:
            # NOTE(review): this branch dispatches on model_mean_type, so the error
            # message should arguably report self.model_mean_type, not self.loss_type.
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape == data.shape
        assert model_variance.shape == model_log_variance.shape == data.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        else:
            return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        # Invert q_sample: x_0 = sqrt(1/ᾱ_t)·x_t − sqrt(1/ᾱ_t − 1)·eps.
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    ''' samples '''

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
        """
        Sample from the model
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
                                                                             return_pred_xstart=True)
        noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
        assert noise.shape == data.shape
        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))

        sample = model_mean
        if use_var:
            # use_var=False yields the deterministic (mean-only) update.
            sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        assert sample.shape == pred_xstart.shape
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, denoise_fn, shape, device,
                      noise_fn=torch.randn, constrain_fn=lambda x, t:x,
                      clip_denoised=True, max_timestep=None, keep_running=False):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

        constrain_fn(x, t) is applied to the partial sample before every
        denoising step (used for masked completion); the default is identity.
        """
        if max_timestep is None:
            final_time = self.num_timesteps
        else:
            final_time = max_timestep

        assert isinstance(shape, (tuple, list))
        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
            img_t = constrain_fn(img_t, t)
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
                                  clip_denoised=clip_denoised, return_pred_xstart=False).detach()

        assert img_t.shape == shape
        return img_t

    def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x):
        """Noise x0 forward to step t, then denoise back to a reconstruction."""
        assert t >= 1

        t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1)
        encoding = self.q_sample(x0, t_vec)

        img_t = encoding

        for k in reversed(range(0,t)):
            img_t = constrain_fn(img_t, k)
            t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=False, return_pred_xstart=False, use_var=True).detach()

        return img_t

    def reconstruct2(self, x0, mask, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda forward, x, t:x):
        """Gradient-based completion: optimize the initial noise z so the full
        reverse chain matches x0 on the masked region.

        NOTE(review): each of the 10 outer iterations stores every intermediate
        state and backpropagates through the whole chain step-by-step with a
        fixed 0.1 step size — experimental and memory-heavy; constrain_fn is
        accepted but never used here.
        """
        z = noise_fn(size=x0.shape, dtype=torch.float, device=x0.device)

        for _ in range(10):
            img_t = z
            outputs =[None for _ in range(len(self.betas))]
            # Forward pass of the reverse chain, caching every state.
            for t in reversed(range(0, len(self.betas))):
                t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t)
                outputs[t] = img_t.detach().cpu().clone()
                img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                      clip_denoised=False, return_pred_xstart=False, use_var=True).detach()

            img_t = torch.autograd.Variable(img_t.data, requires_grad=True)
            # Masked squared distance to the target drives the gradient.
            dist = ((img_t - x0) ** 2 * mask).sum(dim=0).mean()
            grad = torch.autograd.grad(dist, [img_t])[0].detach()
            print('Dist', dist.detach().cpu().item())
            # Manually backpropagate the gradient through each cached step
            # (vector-Jacobian products), ending with a gradient w.r.t. z.
            for t in (range(0, len(outputs))):
                t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t)
                x = outputs[t].to(x0).requires_grad_()
                y = self.p_sample(denoise_fn=denoise_fn, data=x, t=t_, noise_fn=noise_fn,
                                  clip_denoised=False, return_pred_xstart=False, use_var=True)

                grad = torch.autograd.grad(y, [x], grad_outputs=grad)[0]

                z = x.detach().to(x0) - 0.1 * grad.detach()

        # Final reverse chain from the optimized initial noise.
        img_t = z
        for t in reversed(range(0, len(self.betas))):
            t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t)
            img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
                                  clip_denoised=False, return_pred_xstart=False, use_var=True).detach()

        return img_t
# class PVCNN2(PVCNN2Base):
# sa_blocks = [
# ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
# ((64, 3, 16), (256, 0.2, 32, (64, 128))),
# ((128, 3, 8), (64, 0.4, 32, (128, 256))),
# (None, (16, 0.8, 32, (256, 256, 512))),
# ]
# fp_blocks = [
# ((256, 256), (256, 3, 8)),
# ((256, 256), (256, 3, 8)),
# ((256, 128), (128, 2, 16)),
# ((128, 128, 64), (64, 2, 32)),
# ]
#
# def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
# voxel_resolution_multiplier=1):
# super().__init__(
# num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
# dropout=dropout, extra_feature_channels=extra_feature_channels,
# width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
# )
class PVCNN2(PVCNN2Base):
    """Point-Voxel CNN denoiser configuration.

    Only pins the set-abstraction / feature-propagation block layouts; all
    behavior comes from PVCNN2Base. Tuple layouts follow PVCNN2Base's
    convention — presumably (conv cfg, (npoint, radius, nsample, mlp channels))
    per SA block and (mlp channels, conv cfg) per FP block; verify against
    model/pvcnn_generation.py.
    """
    # Encoder (set abstraction): progressively fewer points, wider features.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    # Decoder (feature propagation): mirrors the encoder back to full resolution.
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
                 voxel_resolution_multiplier=1):
        super().__init__(
            num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
            dropout=dropout, extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
        )
class Model(nn.Module):
    """Wraps the PVCNN2 denoiser together with a GaussianDiffusion process.

    The module-level API operates on point clouds shaped (B, D, N): batch,
    coordinate dimension (D == args.nc, typically 3), number of points.
    """
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)

        self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention,
                            dropout=args.dropout, extra_feature_channels=0)

    def prior_kl(self, x0):
        # NOTE(review): _prior_bpd is not defined on the GaussianDiffusion in this
        # file — would raise AttributeError if called; likely lives in the full class.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        # NOTE(review): calc_bpd_loop is likewise absent from the visible
        # GaussianDiffusion — confirm against the training script's version.
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt =  self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt':mse_bt
        }

    def _denoise(self, data, t):
        """Run the network on (B, D, N) data at integer timesteps t of shape (B,)."""
        B, D,N= data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)

        assert out.shape == torch.Size([B, D, N])
        return out

    def get_loss_iter(self, data, noises=None):
        """One training step's per-sample loss at uniformly sampled timesteps.

        NOTE(review): p_losses is not defined on the GaussianDiffusion shown in
        this file; this method is training-only and unused by the eval paths here.
        """
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            # Resample noise for all entries with t != 0 (keeps provided noise only at t == 0).
            noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x,
                    clip_denoised=False, max_timestep=None,
                    keep_running=False):
        """Ancestral sampling; constrain_fn enables masked completion."""
        return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn,
                                            constrain_fn=constrain_fn,
                                            clip_denoised=clip_denoised, max_timestep=max_timestep,
                                            keep_running=keep_running)

    def reconstruct(self, x0, t, constrain_fn=lambda x, t:x):
        return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)

    def reconstruct2(self, x0, mask, constrain_fn):
        return self.diffusion.reconstruct2(x0, mask, self._denoise, constrain_fn=constrain_fn)

    # NOTE(review): train/eval override nn.Module.train(mode=True)/eval() with a
    # narrower signature and only toggle the inner network, not the wrapper itself.
    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        # Wraps only the denoiser (e.g. in DataParallel); diffusion buffers stay on CPU.
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build a beta (noise-variance) schedule of length ``time_num``.

    Args:
        schedule_type: 'linear' for a linear ramp from b_start to b_end, or
            'warm<frac>' (e.g. 'warm0.1', 'warm0.2', 'warm0.5') for a constant
            b_end schedule whose first <frac> of steps ramp linearly from
            b_start to b_end. Any fraction in (0, 1] is accepted, generalizing
            the previously hard-coded 0.1/0.2/0.5 variants.
        b_start, b_end: first/last beta values.
        time_num: number of diffusion timesteps.

    Returns:
        np.ndarray of shape (time_num,).

    Raises:
        NotImplementedError: for unknown schedule names.
    """
    if schedule_type == 'linear':
        return np.linspace(b_start, b_end, time_num)

    if schedule_type.startswith('warm'):
        try:
            frac = float(schedule_type[len('warm'):])
        except ValueError:
            raise NotImplementedError(schedule_type)
        # Constant b_end with a linear warmup over the first `frac` of steps.
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
        return betas

    raise NotImplementedError(schedule_type)
def get_constrain_function(ground_truth, mask, eps, num_steps=1):
    """Build a per-timestep constraint for masked shape completion.

    The returned ``constrain_fn(x, t)`` nudges the partial sample ``x``
    toward ``ground_truth`` on the region selected by ``mask``, taking
    ``num_steps`` gradient-descent-style steps of size ``eps_all[t]``.

    Args:
        ground_truth: target point cloud tensor (broadcastable with x).
        mask: same shape as x; 1 where the constraint applies, 0 elsewhere.
        eps: maximum step size, used at t == 0.
        num_steps: number of correction steps per call.

    Returns:
        constrain_fn(x, t) -> constrained x.
    """
    # Step sizes decay quadratically from eps (t == 0) to 0 (t == 499); the
    # constraint is disabled for t >= 500 (hard-coded to the last 500 of the
    # usual 1000-step schedule).
    eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 500)**2))

    def constrain_fn(x, t):
        eps_ = eps_all[t] if (t < 500) else 0
        for _ in range(num_steps):
            # Gradient step on the masked squared error 0.5*||(x - gt) * mask||^2.
            x = x - eps_ * ((x - ground_truth) * mask)
        return x

    # NOTE: a second, autograd-based variant (constrain_fn2) used to live here
    # but was dead code — never returned or referenced — and has been removed.
    return constrain_fn
#############################################################################
def get_dataset(dataroot, npoints,category,use_mask=False):
    """Build train/val ShapeNet point-cloud datasets.

    The val split is normalized with the train split's mean/std so both share
    the same coordinate frame.

    Args:
        dataroot: root of the ShapeNetCore PC15k dataset.
        npoints: points sampled per shape (both train and test sampling).
        category: list of category names, e.g. ['chair'].
        use_mask: whether the dataset should also return completion masks.

    Returns:
        (train_dataset, val_dataset) pair of ShapeNet15kPointClouds.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True, use_mask = use_mask)
    te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
        categories=category, split='val',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        # reuse the training statistics for val normalization
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        use_mask=use_mask
    )
    return tr_dataset, te_dataset
def get_mvr_dataset(pc_dataroot, mesh_root, npoints,category):
    """Build the multi-view reconstruction val dataset (mesh-backed variant).

    The train point-cloud dataset is instantiated only to obtain its
    normalization statistics, which are applied to the multi-view val set.

    Returns:
        ShapeNet_Multiview_Points val dataset with 200 single-view samples per shape.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
        cache=os.path.join(mesh_root, '../cache'), split='val',
        categories=category,
        npoints=npoints, sv_samples=200,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
    )
    return te_dataset
def get_mvr_dataset_v2(pc_dataroot, views_root, npoints,category):
    """Build the multi-view reconstruction val dataset (pre-rendered views variant).

    Same as get_mvr_dataset but takes a root of pre-rendered views
    (root_views) instead of meshes, and caches next to the point-cloud root.

    Returns:
        ShapeNet_Multiview_Points val dataset with 200 single-view samples per shape.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
        categories=category, split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
        cache=os.path.join(pc_dataroot, '../cache'), split='val',
        categories=category,
        npoints=npoints, sv_samples=200,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
    )
    return te_dataset
def evaluate_gen(opt, ref_pcs, logger):
    """Evaluate generated samples against reference point clouds via JSD.

    Args:
        opt: parsed options; opt.eval_path must point to saved samples (.pth).
        ref_pcs: reference clouds (N, npoints, 3) in world scale, or None to
            rebuild them from the val split (denormalized with per-batch mean/std).
        logger: logger for results.
    """
    if ref_pcs is None:
        _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=False)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
        ref = []
        for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'):
            x = data['test_points']
            m, s = data['mean'].float(), data['std'].float()

            ref.append(x*s + m)  # undo dataset normalization

        ref_pcs = torch.cat(ref, dim=0).contiguous()

    logger.info("Loading sample path: %s"
                % (opt.eval_path))
    sample_pcs = torch.load(opt.eval_path).contiguous()

    logger.info("Generation sample size:%s reference size: %s"
                % (sample_pcs.size(), ref_pcs.size()))

    # Compute metrics
    # NOTE(review): the full metric suite (MMD/COV/1-NNA) is disabled; only JSD runs.
    # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    # results = {k: (v.cpu().detach().item()
    #                if not isinstance(v, float) else v) for k, v in results.items()}
    #
    # pprint(results)
    # logger.info(results)

    jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
    pprint('JSD: {}'.format(jsd))
    logger.info('JSD: {}'.format(jsd))
def evaluate_recon(opt, netE, save_dir, logger):
    """Evaluate masked shape completion on the mesh-backed multi-view dataset.

    For each batch, one of the 24 views is selected (cycling with the batch
    index), the model completes the masked single-view cloud via constrained
    sampling, and EMD/CD against ground truth are reported and saved.

    NOTE(review): the mesh root below is a machine-specific absolute path.
    """
    test_dataset = get_mvr_dataset(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2',
                                   opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    ref = []
    samples = []
    images = []
    masked = []
    k = 0
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        # Cycle through the 24 rendered views, one per batch.
        randind = i%24
        gt_all = data['test_points'][:,randind:randind+1]
        x_all = data['sv_points'][:,randind:randind+1]
        mask_all= data['masks'][:,randind:randind+1]
        img_all = data['image'][:,randind:randind+1]

        # Flatten (batch, view) and move channels first: (B*V, C, N) for the model.
        B,V,N,C = x_all.shape
        x = x_all.reshape(B*V,N,C).transpose(1,2).contiguous()
        mask = mask_all.reshape(B*V,N,C).transpose(1,2).contiguous()
        img = img_all.reshape(B*V, *img_all.shape[2:])

        m, s = data['mean'].float(), data['std'].float()

        # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
        # for t in [10]:
        # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
        #                                                      opt.constrain_steps)).detach().cpu()
        # Constrained ancestral sampling: the mask pins known points to x.
        recon = netE.gen_samples(x.shape, 'cuda',
                                 constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
                                                                     opt.constrain_steps),
                                 clip_denoised=False).detach().cpu()

        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()

        # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
        #     write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #     visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
        #
        #     export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #
        #     k+=1

        # Denormalize back to world scale for metrics / saving.
        x_adj = x.reshape(B,V,N,C)* s + m
        recon_adj = recon.reshape(B,V,N,C)* s + m
        img = img.reshape(B,V,*img.shape[1:])

        ref.append( gt_all * s + m)
        # Keep only the observed (single-view) prefix of each partial cloud.
        masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
        samples.append(recon_adj)
        images.append(img)

    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)

    B, V, N, C = ref_pcs.shape

    torch.save(ref_pcs.reshape(B*V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    torch.save(images.reshape(B*V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B*V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics (EMD + Chamfer distance, per (batch, view) pair)
    results = EMD_CD(sample_pcs.reshape(B*V, N, C),
                     ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)

    results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}

    pprint({key: val.mean() for key, val in results.items()})
    logger.info({key: val.mean() for key, val in results.items()})

    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))

    #
    # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    #
    # results = {k: (v.cpu().detach().item()
    #                if not isinstance(v, float) else v) for k, v in results.items()}
    # pprint(results)
    # logger.info(results)
def evaluate_recon_mvr(opt, netE, save_dir, logger):
    """Evaluate completion on pre-rendered multi-view data with best-of-5 sampling.

    Unlike evaluate_recon, all V views per shape are completed, the completion
    is repeated 5 times, and for each (batch, view) the sample with the lowest
    masked Chamfer-style error to the partial input is kept.

    NOTE(review): the views root below is a machine-specific absolute path.
    """
    test_dataset = get_mvr_dataset_v2(opt.dataroot, '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                                      opt.npoints, opt.classes)
    # _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    ref = []
    samples = []
    images = []
    masked = []
    k = 0
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
        gt_all = data['test_points']
        x_all = data['sv_points']
        mask_all= data['masks']
        img_all = data['image']

        B,V,N,C = x_all.shape
        # Repeat the single ground truth across all V views for comparison.
        gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)

        # visualize_pointcloud_batch(os.path.join(save_dir, 'y_%03d.png'%i), x.transpose(1,2) * s + m, None, None, None)
        # for t in [10]:
        # recon = netE.reconstruct2(x.cuda(), mask.cuda(), get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
        #                                                      opt.constrain_steps)).detach().cpu()
        # Draw 5 completions and keep the best per (batch, view).
        cd_res = []
        recon_res = []
        for p in range(5):
            x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
            mask = mask_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
            img = img_all.reshape(B * V, *img_all.shape[2:])

            m, s = data['mean'].float(), data['std'].float()

            recon = netE.gen_samples(x.shape, 'cuda',
                                     constrain_fn=get_constrain_function(x.cuda(), mask.cuda(), opt.constrain_eps,
                                                                         opt.constrain_steps),
                                     clip_denoised=False).detach().cpu()

            recon = recon.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            # Masked squared error to the observed points: selection criterion.
            cd = (((recon - x)**2)*mask.transpose(1, 2).contiguous()).sum(dim=(1,2))
            cd_res.append(cd)
            recon_res.append(recon)

        cd_res = torch.stack(cd_res, dim=0)
        recon_res = torch.stack(recon_res, dim=0)
        _, argmin = torch.min(cd_res, 0)
        # Select, for each (batch*view) index, the trial with minimal error.
        recon = recon_res[argmin,torch.arange(0,argmin.shape[0])]

        # for d in zip(list(data['test_points'].reshape(B*V,N,C)), list(recon), list(x), list(torch.zeros_like(x))):
        #     write_to_xml_batch(os.path.join(save_dir, 'x_%03d'%k), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #     visualize_pointcloud_batch(os.path.join(save_dir, 'x_%03d.png'%k), torch.stack(d, dim=0)* s[0:1] + m[0:1], None, None, None)
        #
        #     export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d'%k),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
        #
        #     k+=1

        # Denormalize back to world scale for metrics / saving.
        x_adj = x.reshape(B,V,N,C)* s + m
        recon_adj = recon.reshape(B,V,N,C)* s + m
        img = img.reshape(B,V,*img.shape[1:])

        ref.append( gt_all * s + m)
        masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
        samples.append(recon_adj)
        images.append(img)

    ref_pcs = torch.cat(ref, dim=0)
    sample_pcs = torch.cat(samples, dim=0)
    images = torch.cat(images, dim=0)
    masked = torch.cat(masked, dim=0)

    B, V, N, C = ref_pcs.shape

    torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
    torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
    torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
    # Compute metrics (EMD + Chamfer distance, per (batch, view) pair)
    results = EMD_CD(sample_pcs.reshape(B*V, N, C),
                     ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)

    results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}

    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})

    results['pc'] = sample_pcs
    torch.save(results, os.path.join(save_dir, 'ours_results.pth'))

    #
    # results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    #
    # results = {k: (v.cpu().detach().item()
    #                if not isinstance(v, float) else v) for k, v in results.items()}
    # pprint(results)
    # logger.info(results)
def generate(netE, opt, logger):
    """Sample point clouds matching the val split's shapes and save them.

    Samples are generated batch-by-batch with the val set's shape
    (B, npoints, 3), denormalized with the dataset mean/std, written to
    opt.eval_path, and Mitsuba XML scenes are exported per batch.

    Returns:
        The denormalized reference clouds from the val split (for later eval).
    """
    _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes)

    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), drop_last=False)

    with torch.no_grad():
        samples = []
        ref = []

        for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'):
            x = data['test_points'].transpose(1,2)
            m, s = data['mean'].float(), data['std'].float()

            gen = netE.gen_samples(x.shape,
                                   'cuda', clip_denoised=False).detach().cpu()

            gen = gen.transpose(1,2).contiguous()
            x = x.transpose(1,2).contiguous()

            # Denormalize both generated and reference clouds to world scale.
            gen = gen * s + m
            x = x * s + m
            samples.append(gen)
            ref.append(x)

            # visualize_pointcloud_batch(os.path.join(str(Path(opt.eval_path).parent), 'x_%03d.png'%i), gen[:64], None,
            #                            None, None)
            # export_to_pc_batch(os.path.join(str(Path(opt.eval_path).parent), 'ply_%03d'%i),
            #                    gen[:64].numpy())
            # NOTE(review): renderer category is hard-coded to 'airplane' regardless
            # of opt.classes — confirm whether this only affects camera presets.
            write_to_xml_batch(os.path.join(str(Path(opt.eval_path).parent), 'xml_samples_%03d'%i), gen[:min(gen.shape[0], 40)].numpy(), cat='airplane')

        samples = torch.cat(samples, dim=0)
        ref = torch.cat(ref, dim=0)

        torch.save(samples, opt.eval_path)

    return ref
def main(opt):
    """Evaluation entry point: iterate checkpoints and run the selected tasks.

    Sets up the output/logging directories, builds the diffusion model,
    wraps the denoiser in DataParallel, then walks every ``*.pth`` checkpoint
    in ``opt.ckpt_dir`` from the latest epoch down, loading each and running
    generation and/or the enabled evaluations.
    """
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    outf_syn, = setup_output_subdirs(output_dir, 'syn')

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    if opt.cuda:
        netE.cuda()

    def _transform_(m):
        return nn.parallel.DataParallel(m)

    netE = netE.cuda()
    netE.multi_gpu_wrapper(_transform_)

    # netE.eval()

    ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]

    with torch.no_grad():
        # Sort by the epoch number embedded in 'epoch_<n>.pth', newest first.
        for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
            # BUG FIX: a hard-coded absolute checkpoint path previously overrode
            # `ckpt` here, so every iteration evaluated the same file.
            opt.netE = ckpt
            logger.info("Resume Path:%s" % opt.netE)
            resumed_param = torch.load(opt.netE)
            netE.load_state_dict(resumed_param['model_state'])

            ref = None
            if opt.generate:
                epoch = int(os.path.basename(ckpt).split('.')[0].split('_')[-1])
                opt.eval_path = os.path.join(outf_syn, 'epoch_{}_samples.pth'.format(epoch))
                Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
                ref = generate(netE, opt, logger)
            if opt.eval_gen:
                # Evaluate generation (JSD against the val split)
                evaluate_gen(opt, ref, logger)
            if opt.eval_recon:
                # Evaluate single-view completion (mesh-backed dataset)
                evaluate_recon(opt, netE, outf_syn, logger)
            if opt.eval_recon_mvr:
                # Evaluate multi-view completion (pre-rendered views)
                evaluate_recon_mvr(opt, netE, outf_syn, logger)
def parse_args():
    """Build and parse command-line options for airplane generation/evaluation.

    Returns:
        argparse.Namespace with an extra boolean attribute ``cuda`` set from
        ``torch.cuda.is_available()``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): help string says 'input batch size' but this is the dataset root — copy-paste slip.
    parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
    parser.add_argument('--classes', default=['airplane'])
    parser.add_argument('--batch_size', type=int, default=20, help='input batch size')
    parser.add_argument('--workers', type=int, default=16, help='workers')
    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

    # NOTE(review): these "boolean" flags have no type/action, so any value passed on the
    # command line arrives as a non-empty string and is truthy; they only work as defaults.
    parser.add_argument('--generate',default=False)
    parser.add_argument('--eval_gen', default=True)
    parser.add_argument('--eval_recon', default=False)
    parser.add_argument('--eval_recon_mvr', default=False)

    parser.add_argument('--nc', default=3)           # point dimensionality (xyz)
    parser.add_argument('--npoints', default=2048)   # points sampled per shape
    '''model'''
    parser.add_argument('--beta_start', default=0.00001)
    parser.add_argument('--beta_end', default=0.008)
    parser.add_argument('--schedule_type', default='warm0.1')
    parser.add_argument('--time_num', default=1000)  # number of diffusion timesteps

    #params
    parser.add_argument('--attention', default=True)
    parser.add_argument('--dropout', default=0.1)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--loss_type', default='mse')
    parser.add_argument('--model_mean_type', default='eps')
    parser.add_argument('--model_var_type', default='fixedsmall')

    # constrain function (used by completion/reconstruction paths)
    parser.add_argument('--constrain_eps', default=.1)
    parser.add_argument('--constrain_steps', type=int, default=1)

    parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/71_res32_pc_plane_mse_fs_dattng_8e3-1e5_w0.1beta_0wd_0.1do_best/2020-10-13-13-33-53', help="path to netE (to continue training)")

    '''eval'''
    parser.add_argument('--eval_path',
                        default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test_plane/2020-10-18-13-49-20/syn/epoch_2499_samples.pth')

    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')

    parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')

    opt = parser.parse_args()

    # Derive the cuda flag from runtime availability rather than a CLI switch.
    if torch.cuda.is_available():
        opt.cuda = True
    else:
        opt.cuda = False

    return opt
if __name__ == '__main__':
    # Parse CLI options, seed the RNGs for reproducibility, then run the
    # requested generation/evaluation stages.
    opt = parse_args()
    set_seed(opt)
    main(opt)

View file

@ -7,9 +7,7 @@ import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
@ -579,9 +577,8 @@ def main(opt):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='input batch size')
parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/')
parser.add_argument('--dataroot_sv', default='GenReData/')
parser.add_argument('--category', default='chair')
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')

View file

@ -4,16 +4,13 @@ from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_generation import PVCNN2Base
from tqdm import tqdm
@ -271,6 +268,8 @@ class PVCNN2(PVCNN2Base):
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
@ -534,8 +533,8 @@ def main(opt):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
parser.add_argument('--category', default='car')
parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/')
parser.add_argument('--category', default='chair')
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
@ -585,5 +584,3 @@ if __name__ == '__main__':
set_seed(opt)
main(opt)
# results in /viscam/u/alexzhou907/research/diffusion/shapenet/output/test_chair

View file

@ -4,7 +4,6 @@ import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
@ -559,7 +558,7 @@ def train(gpu, opt, output_dir, noises_init):
''' data '''
train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes)
train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.category)
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
@ -609,10 +608,6 @@ def train(gpu, opt, output_dir, noises_init):
else:
start_epoch = 0
def new_x_chain(x, num_chain):
return torch.randn(num_chain, *x.shape[1:], device=x.device)
for epoch in range(start_epoch, opt.niter):
@ -754,7 +749,7 @@ def main():
''' workaround '''
train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes)
train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.category)
noises_init = torch.randn(len(train_dataset), opt.npoints-opt.svpoints, opt.nc)
if opt.dist_url == "env://" and opt.world_size == -1:
@ -772,9 +767,8 @@ def main():
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='input batch size')
parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/')
parser.add_argument('--dataroot_sv', default='GenReData/')
parser.add_argument('--category', default='chair')
parser.add_argument('--bs', type=int, default=48, help='input batch size')

View file

@ -4,7 +4,6 @@ import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
@ -787,10 +786,10 @@ def main():
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/')
parser.add_argument('--category', default='chair')
parser.add_argument('--bs', type=int, default=48, help='input batch size')
parser.add_argument('--bs', type=int, default=16, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')

View file

@ -1,266 +0,0 @@
# Copyright (C) 2012 Daniel Maturana
# This file is part of binvox-rw-py.
#
# binvox-rw-py is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# binvox-rw-py is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.
#
"""
Binvox to Numpy and back.
>>> import numpy as np
>>> import binvox_rw
>>> with open('chair.binvox', 'rb') as f:
... m1 = binvox_rw.read_as_3d_array(f)
...
>>> m1.dims
[32, 32, 32]
>>> m1.scale
41.133000000000003
>>> m1.translate
[0.0, 0.0, 0.0]
>>> with open('chair_out.binvox', 'wb') as f:
... m1.write(f)
...
>>> with open('chair_out.binvox', 'rb') as f:
... m2 = binvox_rw.read_as_3d_array(f)
...
>>> m1.dims==m2.dims
True
>>> m1.scale==m2.scale
True
>>> m1.translate==m2.translate
True
>>> np.all(m1.data==m2.data)
True
>>> with open('chair.binvox', 'rb') as f:
... md = binvox_rw.read_as_3d_array(f)
...
>>> with open('chair.binvox', 'rb') as f:
... ms = binvox_rw.read_as_coord_array(f)
...
>>> data_ds = binvox_rw.dense_to_sparse(md.data)
>>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)
>>> np.all(data_sd==md.data)
True
>>> # the ordering of elements returned by numpy.nonzero changes with axis
>>> # ordering, so to compare for equality we first lexically sort the voxels.
>>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])
True
"""
import numpy as np
class Voxels(object):
    """Container for a binvox voxel model.

    ``data`` holds the voxels either as a 3-D boolean numpy array (dense
    representation) or as a 2-D float numpy array (coordinate
    representation). ``dims``, ``translate`` and ``scale`` are the model
    metadata: ``dims`` are the voxel-grid dimensions (e.g. [32, 32, 32]),
    while ``scale`` and ``translate`` map voxel indices back to the original
    model coordinates:

        x_n = (i + .5) / dims[0]
        y_n = (j + .5) / dims[1]
        z_n = (k + .5) / dims[2]
        x = scale * x_n + translate[0]
        y = scale * y_n + translate[1]
        z = scale * z_n + translate[2]
    """

    def __init__(self, data, dims, translate, scale, axis_order):
        # Only the two layouts binvox can produce are accepted.
        assert (axis_order in ('xzy', 'xyz'))
        self.data = data
        self.dims = dims
        self.translate = translate
        self.scale = scale
        self.axis_order = axis_order

    def clone(self):
        """Return a deep-ish copy (data, dims and translate are copied)."""
        return Voxels(self.data.copy(), self.dims[:], self.translate[:],
                      self.scale, self.axis_order)

    def write(self, fp):
        """Serialize this model to `fp` via the module-level write()."""
        write(self, fp)
def read_header(fp):
    """Parse the ASCII header of a binvox stream (mostly for internal use).

    Returns (dims, translate, scale); leaves `fp` positioned just past the
    'data' line, at the start of the RLE payload.
    """
    magic = fp.readline().strip()
    if not magic.startswith(b'#binvox'):
        raise IOError('Not a binvox file')
    dims = [int(tok) for tok in fp.readline().strip().split(b' ')[1:]]
    translate = [float(tok) for tok in fp.readline().strip().split(b' ')[1:]]
    scale = [float(tok) for tok in fp.readline().strip().split(b' ')[1:]][0]
    fp.readline()  # consume the 'data' line
    return dims, translate, scale
def read_as_3d_array(fp, fix_coords=True):
    """ Read binary binvox format as array.

    Returns the model with accompanying metadata.

    Voxels are stored in a three-dimensional numpy array, which is simple and
    direct, but may use a lot of memory for large models. (Storage requirements
    are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy
    boolean arrays use a byte per element).

    Doesn't do any checks on input except for the '#binvox' line.
    """
    dims, translate, scale = read_header(fp)
    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
    # if just using reshape() on the raw data:
    # indexing the array as array[i,j,k], the indices map into the
    # coords as:
    # i -> x
    # j -> z
    # k -> y
    # if fix_coords is true, then data is rearranged so that
    # mapping is
    # i -> x
    # j -> y
    # k -> z
    # Payload is run-length encoded as (value, count) byte pairs.
    values, counts = raw_data[::2], raw_data[1::2]
    # np.bool was removed in NumPy >= 1.24; the builtin bool is the
    # supported spelling (np.bool was just an alias for it).
    data = np.repeat(values, counts).astype(bool)
    data = data.reshape(dims)
    if fix_coords:
        # xzy to xyz TODO the right thing
        data = np.transpose(data, (0, 2, 1))
        axis_order = 'xyz'
    else:
        axis_order = 'xzy'
    return Voxels(data, dims, translate, scale, axis_order)
def read_as_coord_array(fp, fix_coords=True):
    """ Read binary binvox format as coordinates.

    Returns binvox model with voxels in a "coordinate" representation, i.e. an
    3 x N array where N is the number of nonzero voxels. Each column
    corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates
    of the voxel. (The odd ordering is due to the way binvox format lays out
    data). Note that coordinates refer to the binvox voxels, without any
    scaling or translation.

    Use this to save memory if your model is very sparse (mostly empty).

    Doesn't do any checks on input except for the '#binvox' line.
    """
    dims, translate, scale = read_header(fp)
    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
    # Payload is run-length encoded as (value, count) byte pairs.
    values, counts = raw_data[::2], raw_data[1::2]
    end_indices = np.cumsum(counts)
    indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)
    # np.bool was removed in NumPy >= 1.24; builtin bool is the alias target.
    values = values.astype(bool)
    indices = indices[values]
    end_indices = end_indices[values]
    nz_voxels = []
    for index, end_index in zip(indices, end_indices):
        nz_voxels.extend(range(index, end_index))
    nz_voxels = np.array(nz_voxels)
    # According to the binvox docs:
    #   index = x * wxh + z * width + y;  // wxh = width * height = d * d
    # Use floor division: under Python 3, `/` is true division and would
    # silently produce float coordinates.
    x = nz_voxels // (dims[0]*dims[1])
    zwpy = nz_voxels % (dims[0]*dims[1])  # z*w + y
    z = zwpy // dims[0]
    y = zwpy % dims[0]
    if fix_coords:
        data = np.vstack((x, y, z))
        axis_order = 'xyz'
    else:
        data = np.vstack((x, z, y))
        axis_order = 'xzy'
    return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
def dense_to_sparse(voxel_data, dtype=int):
    """ From dense representation to sparse (coordinate) representation.

    Returns a 3 x N array of the nonzero voxel indices. No coordinate
    reordering. (`dtype` default was np.int, which was removed in
    NumPy >= 1.24; the builtin int is what it aliased.)
    """
    if voxel_data.ndim != 3:
        raise ValueError('voxel_data is wrong shape; should be 3D array.')
    return np.asarray(np.nonzero(voxel_data), dtype)
def sparse_to_dense(voxel_data, dims, dtype=bool):
    """From sparse (3 x N coordinate) representation to a dense boolean grid.

    `dims` may be a scalar (cubic grid) or a 3-sequence. Coordinates outside
    the grid are silently discarded. (`dtype` default was np.bool, removed in
    NumPy >= 1.24; the builtin bool is what it aliased.)
    """
    if voxel_data.ndim != 2 or voxel_data.shape[0] != 3:
        raise ValueError('voxel_data is wrong shape; should be 3xN array.')
    if np.isscalar(dims):
        dims = [dims]*3
    dims = np.atleast_2d(dims).T
    # truncate to integers (np.int removed in NumPy >= 1.24)
    xyz = voxel_data.astype(int)
    # discard voxels that fall outside dims
    valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
    xyz = xyz[:, valid_ix]
    out = np.zeros(dims.flatten(), dtype=dtype)
    out[tuple(xyz)] = True
    return out
#def get_linear_index(x, y, z, dims):
#""" Assuming xzy order. (y increasing fastest.
#TODO ensure this is right when dims are not all same
#"""
#return x*(dims[1]*dims[2]) + z*dims[1] + y
def write(voxel_model, fp):
    """ Write binary binvox format to the binary-mode file object `fp`.

    Note that when saving a model in sparse (coordinate) format, it is first
    converted to dense format.

    Doesn't check if the model is 'sane'.
    """
    if voxel_model.data.ndim == 2:
        # TODO avoid conversion to dense
        dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
    else:
        dense_voxel_data = voxel_model.data

    # binvox is a binary format: write bytes, not str. The previous
    # fp.write('...') / chr(...) calls fail on a file opened 'wb' under
    # Python 3 (the module doctest opens the output file in 'wb').
    fp.write(b'#binvox 1\n')
    fp.write(('dim ' + ' '.join(map(str, voxel_model.dims)) + '\n').encode('ascii'))
    fp.write(('translate ' + ' '.join(map(str, voxel_model.translate)) + '\n').encode('ascii'))
    fp.write(('scale ' + str(voxel_model.scale) + '\n').encode('ascii'))
    fp.write(b'data\n')

    if not voxel_model.axis_order in ('xzy', 'xyz'):
        raise ValueError('Unsupported voxel model axis order')

    # binvox stores xzy on disk, so xyz data is transposed before flattening.
    if voxel_model.axis_order == 'xzy':
        voxels_flat = dense_voxel_data.flatten()
    elif voxel_model.axis_order == 'xyz':
        voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()

    # Run-length encode as (value, count) byte pairs; counts cap at 255.
    state = voxels_flat[0]
    ctr = 0
    for c in voxels_flat:
        if c == state:
            ctr += 1
            # if ctr hits max, dump
            if ctr == 255:
                fp.write(bytes((int(state), ctr)))
                ctr = 0
        else:
            # if switch state, dump
            fp.write(bytes((int(state), ctr)))
            state = c
            ctr = 1
    # flush out remainders
    if ctr > 0:
        fp.write(bytes((int(state), ctr)))
if __name__ == '__main__':
    # Run the module-level doctest examples as a self-test.
    import doctest
    doctest.testmod()

View file

@ -1,46 +0,0 @@
from skimage import measure
import numpy as np
def get_mesh(tsdf_vol, color_vol, threshold=0, vol_max=.5, vol_min=-.5):
    """Compute a mesh from the voxel volume using marching cubes.

    Returns (verts, faces, norms) when color_vol is None, otherwise
    (verts, faces, norms, colors) with per-vertex colors sampled from
    color_vol at the rounded voxel coordinates. Vertices are mapped from
    voxel-grid coordinates into world coordinates [vol_min, vol_max].
    """
    vol_origin = vol_min
    voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1]
    # marching_cubes_lewiner was removed from scikit-image >= 0.19; fall back
    # to its replacement, measure.marching_cubes.
    marching = getattr(measure, 'marching_cubes_lewiner', None) or measure.marching_cubes
    verts, faces, norms, vals = marching(tsdf_vol, level=threshold)
    verts_ind = np.round(verts).astype(int)
    verts = verts * voxel_size + vol_origin  # voxel grid coordinates to world coordinates
    # Get vertex colors
    if color_vol is None:
        return verts, faces, norms
    colors = color_vol[:, verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].T
    return verts, faces, norms, colors
def get_point_cloud(tsdf_vol, color_vol, vol_max=0.5, vol_min=-0.5):
    """Extract the level-0 isosurface vertices of `tsdf_vol` as a colored
    point cloud (N x 6: xyz followed by rgb sampled from `color_vol`)."""
    vol_origin = vol_min
    voxel_size = (vol_max - vol_min) / tsdf_vol.shape[-1]
    # marching_cubes_lewiner was removed from scikit-image >= 0.19; fall back
    # to its replacement, measure.marching_cubes.
    marching = getattr(measure, 'marching_cubes_lewiner', None) or measure.marching_cubes
    verts = marching(tsdf_vol, level=0)[0]
    verts_ind = np.round(verts).astype(int)
    verts = verts * voxel_size + vol_origin
    # Get vertex colors
    colors = color_vol[:, verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].T
    pc = np.hstack([verts, colors])
    return pc
def sparse_to_dense_voxel(coords, feats, res):
    """Scatter per-voxel features into a dense (res, res, res) grid.

    coords: (N, 3) integer-valued voxel indices; feats: (N, 1) feature
    values. Cells not listed in `coords` stay zero.
    """
    idx = coords.astype('int64', copy=False)
    grid = np.zeros((res, res, res), dtype=feats.dtype)
    grid[idx[:, 0], idx[:, 1], idx[:, 2]] = feats[:, 0].astype(grid.dtype, copy=False)
    return grid

View file

@ -1,146 +0,0 @@
import numpy as np
from pathlib import Path
import os
def standardize_bbox(pcl, points_per_object, scale=None):
    """Randomly subsample `pcl` to `points_per_object` points and normalize
    them into the unit cube [-0.5, 0.5]^3, centered on the bounding-box
    midpoint. If `scale` is given it overrides the bbox extent used as the
    divisor."""
    chosen = np.random.choice(pcl.shape[0], points_per_object, replace=False)
    np.random.shuffle(chosen)
    sampled = pcl[chosen]  # n by 3
    lo = np.amin(sampled, axis=0)
    hi = np.amax(sampled, axis=0)
    mid = (lo + hi) / 2.
    if scale is None:
        scale = np.amax(hi - lo)
    return ((sampled - mid) / scale).astype(np.float32)  # [-0.5, 0.5]
xml_head = \
"""
<scene version="0.6.0">
<integrator type="path">
<integer name="maxDepth" value="-1"/>
</integrator>
<sensor type="perspective">
<float name="farClip" value="100"/>
<float name="nearClip" value="0.1"/>
<transform name="toWorld">
<lookat origin="3,3,3" target="0,0,0" up="0,0,1"/>
</transform>
<float name="fov" value="{}"/>
<sampler type="ldsampler">
<integer name="sampleCount" value="256"/>
</sampler>
<film type="hdrfilm">
<integer name="width" value="256"/>
<integer name="height" value="256"/>
<rfilter type="gaussian"/>
<boolean name="banner" value="false"/>
</film>
</sensor>
<bsdf type="roughplastic" id="surfaceMaterial">
<string name="distribution" value="ggx"/>
<float name="alpha" value="0.05"/>
<float name="intIOR" value="1.46"/>
<rgb name="diffuseReflectance" value="1,1,1"/> <!-- default 0.5 -->
</bsdf>
"""
xml_ball_segment = \
"""
<shape type="sphere">
<float name="radius" value="{}"/>
<transform name="toWorld">
<translate x="{}" y="{}" z="{}"/>
</transform>
<bsdf type="diffuse">
<rgb name="reflectance" value="{},{},{}"/>
</bsdf>
</shape>
"""
xml_tail = \
"""
<shape type="rectangle">
<bsdf type="diffuse">
<rgb name="reflectance" value="1"/>
</bsdf>
<transform name="toWorld">
<scale x="100" y="100" z="1"/>
<translate x="0" y="0" z="{}"/>
</transform>
</shape>
<shape type="sphere">
<transform name="toWorld">
<scale x="10" y="10" z="1"/>
<lookat origin="2,0,18" target="0,0,0" up="0,0,1"/>
</transform>
<emitter type="area">
<rgb name="radiance" value="5"/>
</emitter>
</shape>
<shape type="sphere">
<transform name="toWorld">
<scale x="10" y="10" z="1"/>
<lookat origin="-30,0,18" target="-100,0,0" up="0,0,1"/>
</transform>
<emitter type="area">
<rgb name="radiance" value="5"/>
</emitter>
</shape>
</scene>
"""
def colormap_fn(x, y, z):
    """Map a coordinate triple to a unit-norm RGB color; components are
    clipped to [0.001, 1] before normalization so the result is never the
    zero vector."""
    rgb = np.clip(np.array([x, y, z]), 0.001, 1.0)
    rgb = rgb / np.sqrt(np.sum(rgb ** 2))
    return [rgb[0], rgb[1], rgb[2]]
# Candidate RGB palettes (0-255) keyed by a single-letter code.
color_dict = {'r': [163, 102, 96], 'g': [20, 130, 3],
              'o': [145, 128, 47], 'b': [91, 102, 112], 'p':[133,111,139], 'br':[111,92,81]}
# Per-category rendering defaults: palette code, camera field of view
# (degrees), and sphere radius used when drawing each point.
color_map = {'airplane': 'r', 'chair': 'o', 'car': 'b', 'table': 'p', 'lamp':'br'}
fov_map = {'airplane': 12, 'chair': 16, 'car':15, 'table': 13, 'lamp':13}
radius_map = {'airplane': 0.02, 'chair': 0.035, 'car': 0.01, 'table':0.035, 'lamp':0.035}
def write_to_xml_batch(dir, pcl_batch, filenames=None, color_batch=None, cat='airplane'):
    """Write one Mitsuba XML scene per point cloud in `pcl_batch` into `dir`.

    dir: output directory (created if missing).
    pcl_batch: (B, N, 3) array of point clouds.
    filenames: optional list of B output names; defaults to 'sample_{k}.xml'.
    color_batch: optional (B, N, 3) per-point RGB in [0, 1]; when absent the
        category's default palette from color_dict/color_map is used.
    cat: category key into fov_map/radius_map/color_map.
    """
    default_color = color_map[cat]
    Path(dir).mkdir(parents=True, exist_ok=True)
    if filenames is not None:
        assert len(filenames) == pcl_batch.shape[0]
    # mins = np.amin(pcl_batch, axis=(0,1))
    # maxs = np.amax(pcl_batch, axis=(0,1))
    # scale = 1; print(np.amax(maxs - mins))
    for k, pcl in enumerate(pcl_batch):
        xml_segments = [xml_head.format(fov_map[cat])]
        # Normalize to the unit cube, then permute axes into the renderer's
        # frame and lift slightly above the ground plane.
        pcl = standardize_bbox(pcl, pcl.shape[0])
        pcl = pcl[:, [2, 0, 1]]
        pcl[:, 0] *= -1
        pcl[:, 2] += 0.0125
        for i in range(pcl.shape[0]):
            if color_batch is not None:
                color = color_batch[k, i]
            else:
                color = np.array(color_dict[default_color]) / 255
            # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125)
            xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color))
        xml_segments.append(
            xml_tail.format(pcl[:, 2].min()))
        xml_content = str.join('', xml_segments)
        if filenames is None:
            fn = 'sample_{}.xml'.format(k)
        else:
            fn = filenames[k]
        with open(os.path.join(dir, fn), 'w') as f:
            f.write(xml_content)

View file

@ -1,170 +0,0 @@
import numpy as np
from pathlib import Path
import os
def standardize_bbox(pcl, points_per_object):
    """Randomly subsample `pcl` to `points_per_object` points and normalize
    them into the unit cube [-0.5, 0.5]^3, centered on the bounding-box
    midpoint and scaled by the largest bbox extent."""
    chosen = np.random.choice(pcl.shape[0], points_per_object, replace=False)
    np.random.shuffle(chosen)
    sampled = pcl[chosen]  # n by 3
    lo = np.amin(sampled, axis=0)
    hi = np.amax(sampled, axis=0)
    mid = (lo + hi) / 2.
    extent = np.amax(hi - lo)
    return ((sampled - mid) / extent).astype(np.float32)  # [-0.5, 0.5]
xml_head = \
"""
<scene version="0.6.0">
<integrator type="path">
<integer name="maxDepth" value="-1"/>
</integrator>
<sensor type="perspective">
<float name="farClip" value="100"/>
<float name="nearClip" value="0.1"/>
<transform name="toWorld">
<lookat origin="{},{},{}" target="0,0,0" up="0,0,1"/>
</transform>
<float name="fov" value="20"/>
<sampler type="ldsampler">
<integer name="sampleCount" value="256"/>
</sampler>
<film type="hdrfilm">
<integer name="width" value="480"/>
<integer name="height" value="480"/>
<rfilter type="gaussian"/>
<boolean name="banner" value="false"/>
</film>
</sensor>
<bsdf type="roughplastic" id="surfaceMaterial">
<string name="distribution" value="ggx"/>
<float name="alpha" value="0.05"/>
<float name="intIOR" value="1.46"/>
<rgb name="diffuseReflectance" value="1,1,1"/> <!-- default 0.5 -->
</bsdf>
"""
xml_ball_segment = \
"""
<shape type="sphere">
<float name="radius" value="{}"/>
<transform name="toWorld">
<translate x="{}" y="{}" z="{}"/>
</transform>
<bsdf type="diffuse">
<rgb name="reflectance" value="{},{},{}"/>
</bsdf>
</shape>
"""
xml_tail = \
"""
<shape type="rectangle">
<bsdf type="diffuse">
<rgb name="reflectance" value="1"/>
</bsdf>
<transform name="toWorld">
<scale x="100" y="100" z="1"/>
<translate x="0" y="0" z="{}"/>
</transform>
</shape>
<shape type="sphere">
<transform name="toWorld">
<scale x="10" y="10" z="1"/>
<lookat origin="2,0,18" target="0,0,0" up="0,0,1"/>
</transform>
<emitter type="area">
<rgb name="radiance" value="5"/>
</emitter>
</shape>
<shape type="sphere">
<transform name="toWorld">
<scale x="10" y="10" z="1"/>
<lookat origin="-30,0,18" target="-100,0,0" up="0,0,1"/>
</transform>
<emitter type="area">
<rgb name="radiance" value="5"/>
</emitter>
</shape>
</scene>
"""
def colormap_fn(x, y, z):
    """Map a coordinate triple to a unit-norm RGB color; components are
    clipped to [0.001, 1] before normalization so the result is never the
    zero vector."""
    rgb = np.clip(np.array([x, y, z]), 0.001, 1.0)
    rgb = rgb / np.sqrt(np.sum(rgb ** 2))
    return [rgb[0], rgb[1], rgb[2]]
# Candidate RGB palettes (0-255) keyed by a single-letter code.
color_dict = {'r': [163, 102, 96], 'p': [133,111,139], 'g': [20, 130, 3],
              'o': [145, 128, 47], 'b': [91, 102, 112]}
# Per-category rendering defaults: palette code, camera field of view
# (degrees), and sphere radius used when drawing each point.
color_map = {'airplane': 'r', 'chair': 'o', 'car': 'b', 'table': 'p'}
fov_map = {'airplane': 12, 'chair': 15, 'car':12, 'table':12}
radius_map = {'airplane': 0.0175, 'chair': 0.035, 'car': 0.025, 'table': 0.02}
def write_to_xml_batch(dir, pcl_batch, color_batch=None, cat='airplane', elev=15, azim=45, radius=np.sqrt(18)):
    """Write one Mitsuba XML scene per point cloud in `pcl_batch` into `dir`
    as 'sample_{k}.xml'.

    The camera sits on a sphere of `radius` at (elev, azim) degrees; point
    colors default to the category palette unless `color_batch` (B, N, 3 in
    [0, 1]) is given.
    """
    # Spherical (elev, azim, radius) -> Cartesian camera position.
    elev_rad = elev * np.pi / 180
    azim_rad = azim * np.pi / 180
    x = radius * np.cos(elev_rad)*np.cos(azim_rad)
    y = radius * np.cos(elev_rad)*np.sin(azim_rad)
    z = radius * np.sin(elev_rad)
    default_color = color_map[cat]
    Path(dir).mkdir(parents=True, exist_ok=True)
    for k, pcl in enumerate(pcl_batch):
        xml_segments = [xml_head.format(x,y,z)]
        # Normalize to the unit cube, permute axes into the renderer's frame
        # and lift slightly above the ground plane.
        pcl = standardize_bbox(pcl, pcl.shape[0])
        pcl = pcl[:, [2, 0, 1]]
        pcl[:, 0] *= -1
        pcl[:, 2] += 0.0125
        for i in range(pcl.shape[0]):
            if color_batch is not None:
                color = color_batch[k, i]
            else:
                color = np.array(color_dict[default_color]) / 255
            # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125)
            # Use the per-category sphere radius; the previous hard-coded
            # 0.0175 was inconsistent with radius_map and with write_to_xml.
            xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color))
        xml_segments.append(
            xml_tail.format(pcl[:, 2].min()))
        xml_content = str.join('', xml_segments)
        with open(os.path.join(dir, 'sample_{}.xml'.format(k)), 'w') as f:
            f.write(xml_content)
def write_to_xml(file, pcl, cat='airplane', elev=15, azim=45, radius=np.sqrt(18)):
    """Write a single point cloud `pcl` (N, 3) as a Mitsuba XML scene to
    `file`.

    The camera sits on a sphere of `radius` at (elev, azim) degrees; points
    use the per-category default color and sphere radius.
    """
    assert pcl.ndim == 2
    # Spherical (elev, azim, radius) -> Cartesian camera position.
    elev_rad = elev * np.pi / 180
    azim_rad = azim * np.pi / 180
    x = radius * np.cos(elev_rad)*np.cos(azim_rad)
    y = radius * np.cos(elev_rad)*np.sin(azim_rad)
    z = radius * np.sin(elev_rad)
    default_color = color_map[cat]
    xml_segments = [xml_head.format(x,y,z)]
    # Normalize to the unit cube, permute axes into the renderer's frame and
    # lift slightly above the ground plane.
    pcl = standardize_bbox(pcl, pcl.shape[0])
    pcl = pcl[:, [2, 0, 1]]
    pcl[:, 0] *= -1
    pcl[:, 2] += 0.0125
    for i in range(pcl.shape[0]):
        color = np.array(color_dict[default_color]) / 255
        # color = colormap_fn(pcl[i,0]+0.5,pcl[i,1]+0.5,pcl[i,2]+0.5-0.0125)
        xml_segments.append(xml_ball_segment.format(radius_map[cat], pcl[i, 0], pcl[i, 1], pcl[i, 2], *color))
    xml_segments.append(
        xml_tail.format(pcl[:, 2].min()))
    xml_content = str.join('', xml_segments)
    with open(file, 'w') as f:
        f.write(xml_content)

View file

@ -1,86 +0,0 @@
import sys
sys.path.append('..')
import argparse
import os
import numpy as np
import trimesh
import glob
from joblib import Parallel, delayed
import re
from utils.mitsuba_renderer import write_to_xml_batch
def rotation_matrix(axis, theta):
    """
    Return the 3x3 rotation matrix for a counterclockwise rotation of
    `theta` radians about `axis` (normalized internally), using the
    Euler-Rodrigues formula.
    """
    axis = np.asarray(axis)
    axis = axis / np.sqrt(axis @ axis)
    half = theta / 2.0
    w = np.cos(half)
    qx, qy, qz = -axis * np.sin(half)
    return np.array([
        [w*w + qx*qx - qy*qy - qz*qz, 2 * (qx*qy + w*qz), 2 * (qx*qz - w*qy)],
        [2 * (qx*qy - w*qz), w*w + qy*qy - qx*qx - qz*qz, 2 * (qy*qz + w*qx)],
        [2 * (qx*qz + w*qy), 2 * (qy*qz - w*qx), w*w + qz*qz - qx*qx - qy*qy],
    ])
def rotate(vertices, faces):
    '''
    vertices: [numpoints, 3]
    '''
    # Rotate the vertices by 3*pi/4 about the +x axis (right-multiplying by
    # the transposed rotation matrix); faces are passed through unchanged.
    N = rotation_matrix([1, 0, 0], 3* np.pi / 4).transpose()
    # M = rotation_matrix([0, 1, 0], -np.pi / 2).transpose()
    v, f = vertices.dot(N), faces
    return v, f
def as_mesh(scene_or_mesh):
    """Flatten a trimesh.Scene into a single Trimesh; anything else is
    returned unchanged."""
    if not isinstance(scene_or_mesh, trimesh.Scene):
        return scene_or_mesh
    parts = [trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
             for g in scene_or_mesh.geometry.values()]
    return trimesh.util.concatenate(parts)
def process_one(shape_dir, cat):
    # Sample 2048 surface points from each .obj mesh under shape_dir (after
    # the canonical rotation) and emit one Mitsuba XML scene per mesh, next
    # to the inputs.
    pc_paths = glob.glob(os.path.join(shape_dir, "*.obj"))
    pc_paths = sorted(pc_paths)
    xml_paths = []  #[re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths]
    gen_pcs = []
    for path in pc_paths:
        sample_mesh = trimesh.load(path, force='mesh')
        v, f = rotate(sample_mesh.vertices,sample_mesh.faces)
        mesh = trimesh.Trimesh(v, f)
        sample_pts = trimesh.sample.sample_surface(mesh, 2048)[0]
        gen_pcs.append(sample_pts)
        xml_paths.append(re.sub('.obj', '.xml', os.path.basename(path)))
    gen_pcs = np.stack(gen_pcs, axis=0)
    # NOTE(review): assumes at least one .obj exists — pc_paths[0] raises
    # IndexError on an empty directory.
    write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths, cat=cat)
def process(args):
    """Run process_one over every shape folder under args.src in parallel
    (folders whose name starts with 'x' are skipped)."""
    shape_names = [n for n in sorted(os.listdir(args.src)) if
                   os.path.isdir(os.path.join(args.src, n)) and not n.startswith('x')]
    all_shape_dir = [os.path.join(args.src, name) for name in shape_names]
    # process_one(shape_dir, cat) takes two arguments; the previous call
    # passed only the path and raised TypeError.
    Parallel(n_jobs=10, verbose=2)(delayed(process_one)(path, args.cat) for path in all_shape_dir)
def main():
    # CLI entry point: sample surface points from every .obj under --src and
    # write Mitsuba XML scenes next to them, using --cat for render defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str)
    parser.add_argument("--cat", type=str)
    args = parser.parse_args()
    process_one(args.src, args.cat)


if __name__ == '__main__':
    main()

View file

@ -1,54 +0,0 @@
import sys
sys.path.append('..')
import argparse
import os
import numpy as np
import trimesh
import glob
from joblib import Parallel, delayed
import re
from utils.mitsuba_renderer import write_to_xml_batch
def process_one(shape_dir):
    # Collect the generated completions ('fake*.ply') plus the partial input
    # ('raw.ply') from shape_dir and emit one Mitsuba XML scene for each.
    pc_paths = glob.glob(os.path.join(shape_dir, "fake*.ply"))
    pc_paths = sorted(pc_paths)
    xml_paths = [re.sub('.ply', '.xml', os.path.basename(pth)) for pth in pc_paths]
    gen_pcs = []
    for path in pc_paths:
        sample_pts = trimesh.load(path)
        sample_pts = np.array(sample_pts.vertices)
        gen_pcs.append(sample_pts)
    raw_pc = np.array(trimesh.load(os.path.join(shape_dir, "raw.ply")).vertices)
    # Pad the raw (partial) cloud by repeating its first point so it stacks
    # to the same point count as the generated clouds.
    raw_pc = np.concatenate([raw_pc, np.tile(raw_pc[0:1], (gen_pcs[0].shape[0]-raw_pc.shape[0],1))])
    gen_pcs.append(raw_pc)
    gen_pcs = np.stack(gen_pcs, axis=0)
    xml_paths.append('raw.xml')
    write_to_xml_batch(os.path.dirname(pc_paths[0]), gen_pcs, xml_paths)
def process(args):
    """Run process_one over every shape folder under args.src in parallel
    (folders whose name starts with 'x' are skipped)."""
    candidates = sorted(os.listdir(args.src))
    shape_dirs = [
        os.path.join(args.src, name)
        for name in candidates
        if os.path.isdir(os.path.join(args.src, name)) and not name.startswith('x')
    ]
    Parallel(n_jobs=10, verbose=2)(delayed(process_one)(d) for d in shape_dirs)
def main():
    """CLI entry point: render all completion point clouds found under
    --src as Mitsuba XML scenes."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str)
    args = parser.parse_args()
    # process_one expects a directory path; the previous call passed the
    # whole argparse namespace, which breaks os.path.join inside.
    process_one(args.src)


if __name__ == '__main__':
    main()