# PVD/test_completion.py
import argparse
import os
from pprint import pprint

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
from torch.distributions import Normal
from tqdm import tqdm

from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
from metrics.evaluation_metrics import EMD_CD, compute_all_metrics
from model.pvcnn_completion import PVCNN2Base
from utils.file_utils import *
"""
2021-10-19 20:54:46 +00:00
models
2023-04-11 09:12:58 +00:00
"""
2021-10-19 20:54:46 +00:00
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    # Assumes data lies in [0, 1]; the 0.001 / 0.999 cases below handle the boundary bins.
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
    min_in = inv_stdv * (centered_x - 0.5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001,
log_cdf_plus,
torch.where(
x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
),
)
assert log_probs.shape == x.shape
return log_probs
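

# The GaussianDiffusion class below implements a standard DDPM-style forward /
# reverse process (cf. Ho et al., 2020), specialized for shape completion: the
# first `sv_points` columns of each cloud hold the fixed partial input, and only
# the remaining columns are diffused and denoised.
# Forward process: q(x_t | x_0) = N(sqrt(alphabar_t) x_0, (1 - alphabar_t) I),
# where alpha_t = 1 - beta_t and alphabar_t = prod_{s<=t} alpha_s; the
# constructor precomputes these coefficients once, in float64 for accuracy.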
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
        (timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
        alphas = 1.0 - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(
torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
)
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
        (bs,) = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
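
    # Example: with betas of shape [T], t of shape [B], and x of shape [B, C, N],
    # _extract(betas, t, x.shape) returns betas[t] reshaped to [B, 1, 1] so the
    # per-timestep scalars broadcast over the point-cloud dimensions.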
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
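
    # q_sample uses the closed form x_t = sqrt(alphabar_t) x_0 + sqrt(1 - alphabar_t) eps,
    # eps ~ N(0, I): a sample from q(x_t | x_0) in one shot instead of t sequential steps.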
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(
self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
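
    # The posterior coefficients follow from Bayes' rule on the Gaussian forward process:
    #   mean = [beta_t sqrt(alphabar_{t-1}) x_0 + (1 - alphabar_{t-1}) sqrt(alpha_t) x_t] / (1 - alphabar_t)
    #   var  = beta_t (1 - alphabar_{t-1}) / (1 - alphabar_t)
    # i.e. exactly posterior_mean_coef1/2 and posterior_variance from __init__.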
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        model_output = denoise_fn(data, t)[:, :, self.sv_points :]

        if self.model_var_type in ["fixedsmall", "fixedlarge"]:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
"fixedlarge": (
self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
),
"fixedsmall": (
self.posterior_variance.to(data.device),
self.posterior_log_variance_clipped.to(data.device),
),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == "eps":
            x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points :], t=t, eps=model_output)

            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points :], t=t)
        else:
            raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
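
    # _predict_xstart_from_eps inverts the q_sample closed form:
    #   x_0 = (x_t - sqrt(1 - alphabar_t) eps) / sqrt(alphabar_t)
    #       = sqrt(1 / alphabar_t) x_t - sqrt(1 / alphabar_t - 1) eps.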
    """ samples """
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, : self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
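
    # Note: p_sample denoises only the free points. The denoiser sees the full
    # [partial | noisy] cloud, the sampled update applies to the columns past
    # sv_points, and the clean partial input is re-attached to every sample.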
def p_sample_loop(
self, partial_x, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False
):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(
denoise_fn=denoise_fn,
data=img_t,
t=t_,
noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False,
)
assert img_t[:, :, self.sv_points :].shape == shape
return img_t
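
    # Usage sketch (shapes illustrative): for B partial clouds with C channels,
    # sv_points known points and N total points,
    #   full = self.p_sample_loop(x[:, :, :sv_points], denoise_fn,
    #                             shape=(B, C, N - sv_points), device=device)
    # returns a [B, C, N] cloud with the partial input occupying the first columns.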
def p_sample_loop_trajectory(
self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False
):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
repeat_noise_steps (int): Number of denoising timesteps in which the same noise
is used across the batch. If >= 0, the initial noise is the same for all batch elemements.
"""
assert isinstance(shape, (tuple, list))
        total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
        for t in reversed(range(0, total_steps)):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(
denoise_fn=denoise_fn,
data=img_t,
t=t_,
noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False,
)
if t % freq == 0 or t == total_steps - 1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
    """losses"""
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=data_start[:, :, self.sv_points :], x_t=data_t[:, :, self.sv_points :], t=t
)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
)
        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.0)
return (kl, pred_xstart) if return_pred_xstart else kl
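
    # The KL above is averaged over all non-batch dimensions and divided by ln(2),
    # so _vb_terms_bpd reports the variational-bound term in bits per dimension.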
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(
data_start[:, :, self.sv_points :].shape, dtype=data_start.dtype, device=data_start.device
)
data_t = self.q_sample(x_start=data_start[:, :, self.sv_points :], t=t, noise=noise)
if self.loss_type == "mse":
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:, :, : self.sv_points], data_t], dim=-1), t)[
:, :, self.sv_points :
]
losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == "kl":
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn,
data_start=data_start,
data_t=data_t,
t=t,
clip_denoised=False,
return_pred_xstart=False,
)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
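
    # With loss_type == "mse" this is the "simple" DDPM objective: regress the
    # injected noise eps with an unweighted MSE, which implicitly reweights the
    # variational bound across timesteps (see the comment above eps_recon).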
    """debug"""
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(
mean1=qt_mean,
logvar1=qt_log_variance,
mean2=torch.tensor([0.0]).to(qt_mean),
logvar2=torch.tensor([0.0]).to(qt_log_variance),
)
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
            vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat(
[x_start[:, :, : self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points :], t=t_b)],
dim=-1,
)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn,
data_start=x_start,
data_t=data_t,
t=t_b,
clip_denoised=clip_denoised,
return_pred_xstart=True,
)
                # MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points :].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points :]) ** 2).mean(
dim=list(range(1, len(pred_xstart.shape)))
)
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
            prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points :])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
            assert total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
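
    # calc_bpd_loop decomposes the bound per batch element: total_bpd_b sums the
    # per-timestep terms (vals_bt_) over all T steps and adds the prior term from
    # _prior_bpd, everything reported in bits per dimension.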
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(
self,
num_classes,
sv_points,
embed_dim,
use_att,
dropout,
extra_feature_channels=3,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
        super().__init__(
num_classes=num_classes,
sv_points=sv_points,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(
num_classes=args.nc,
sv_points=args.svpoints,
embed_dim=args.embed_dim,
use_att=args.attention,
dropout=args.dropout,
extra_feature_channels=0,
)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}
def _denoise(self, data, t):
        B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
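
    # get_loss_iter draws one uniform timestep per batch element; caller-supplied
    # noises are kept only where t == 0 and are resampled everywhere else.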
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False):
return self.diffusion.p_sample_loop(
partial_x,
self._denoise,
shape=shape,
device=device,
noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running,
)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    if schedule_type == "linear":
betas = np.linspace(b_start, b_end, time_num)
    elif schedule_type == "warm0.1":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    elif schedule_type == "warm0.2":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    elif schedule_type == "warm0.5":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
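

# Example (matching the defaults in parse_args below):
#   betas = get_betas("linear", 0.0001, 0.02, 1000)
# yields 1000 betas rising linearly from 1e-4 to 2e-2, the standard DDPM linear schedule.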
#############################################################################
def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
tr_dataset = ShapeNet15kPointClouds(
root_dir=pc_dataroot,
categories=[category],
split="train",
tr_sample_size=npoints,
te_sample_size=npoints,
        scale=1.0,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True,
)
te_dataset = ShapeNet_Multiview_Points(
root_pc=pc_dataroot,
root_views=views_root,
cache=os.path.join(pc_dataroot, "../cache"),
split="val",
        categories=[category],
npoints=npoints,
sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
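

# Note: the train split is loaded above only so its normalization statistics
# (all_points_mean / all_points_std) can be reused by the multiview val set;
# the returned dataset is the val-split ShapeNet_Multiview_Points.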
def evaluate_recon_mvr(opt, model, save_dir, logger):
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.category)
test_dataloader = torch.utils.data.DataLoader(
test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
)
ref = []
samples = []
masked = []
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Reconstructing Samples"):
gt_all = data["test_points"]
x_all = data["sv_points"]
B, V, N, C = x_all.shape
gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1)
        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        m, s = data["mean"].float(), data["std"].float()
recon = (
model.gen_samples(
x[:, :, : opt.svpoints].cuda(), x[:, :, opt.svpoints :].shape, "cuda", clip_denoised=False
)
.detach()
.cpu()
)
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
x_adj = x.reshape(B, V, N, C) * s + m
recon_adj = recon.reshape(B, V, N, C) * s + m
ref.append(gt_all * s + m)
masked.append(x_adj[:, :, : test_dataloader.dataset.sv_samples, :])
samples.append(recon_adj)
ref_pcs = torch.cat(ref, dim=0)
sample_pcs = torch.cat(samples, dim=0)
masked = torch.cat(masked, dim=0)
B, V, N, C = ref_pcs.shape
    torch.save(ref_pcs.reshape(B, V, N, C), os.path.join(save_dir, "recon_gt.pth"))
    torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, "recon_masked.pth"))
# Compute metrics
results = EMD_CD(sample_pcs.reshape(B * V, N, C), ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)
    results = {
        ky: val.reshape(B, V) if val.shape == torch.Size([B * V]) else val
        for ky, val in results.items()
    }
pprint({key: val.mean().item() for key, val in results.items()})
logger.info({key: val.mean().item() for key, val in results.items()})
results["pc"] = sample_pcs
torch.save(results, os.path.join(save_dir, "ours_results.pth"))
del ref_pcs, masked, results
def evaluate_saved(opt, saved_dir):
# ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
gt_pth = saved_dir + "/recon_gt.pth"
ours_pth = saved_dir + "/ours_results.pth"
gt = torch.load(gt_pth).permute(1, 0, 2, 3)
ours = torch.load(ours_pth)["pc"].permute(1, 0, 2, 3)
all_res = {}
for i, (gt_, ours_) in enumerate(zip(gt, ours)):
results = compute_all_metrics(gt_, ours_, opt.batch_size)
for key, val in results.items():
if i == 0:
all_res[key] = val
else:
all_res[key] += val
pprint(results)
for key, val in all_res.items():
all_res[key] = val / gt.shape[0]
pprint({key: val.mean().item() for key, val in all_res.items()})
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
    (outf_syn,) = setup_output_subdirs(output_dir, "syn")
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
model.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
model = model.cuda()
model.multi_gpu_wrapper(_transform_)
model.eval()
with torch.no_grad():
        logger.info("Resume Path: %s" % opt.model)
resumed_param = torch.load(opt.model)
        model.load_state_dict(resumed_param["model_state"])
if opt.eval_recon_mvr:
            # Evaluate multi-view completion reconstruction
evaluate_recon_mvr(opt, model, outf_syn, logger)
if opt.eval_saved:
evaluate_saved(opt, outf_syn)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--dataroot_pc", default="ShapeNetCore.v2.PC15k/")
parser.add_argument("--dataroot_sv", default="GenReData/")
parser.add_argument("--category", default="chair")
parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
parser.add_argument("--workers", type=int, default=16, help="workers")
parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")
parser.add_argument("--eval_recon_mvr", default=True)
parser.add_argument("--eval_saved", default=True)
parser.add_argument("--nc", default=3)
parser.add_argument("--npoints", default=2048)
parser.add_argument("--svpoints", default=200)
"""model"""
parser.add_argument("--beta_start", default=0.0001)
parser.add_argument("--beta_end", default=0.02)
parser.add_argument("--schedule_type", default="linear")
parser.add_argument("--time_num", default=1000)
# params
parser.add_argument("--attention", default=True)
parser.add_argument("--dropout", default=0.1)
parser.add_argument("--embed_dim", type=int, default=64)
parser.add_argument("--loss_type", default="mse")
parser.add_argument("--model_mean_type", default="eps")
parser.add_argument("--model_var_type", default="fixedsmall")
parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")
"""eval"""
parser.add_argument("--eval_path", default="")
parser.add_argument("--manualSeed", default=42, type=int, help="random seed")
parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")
opt = parser.parse_args()
    opt.cuda = torch.cuda.is_available()
return opt
if __name__ == "__main__":
opt = parse_args()
main(opt)