653 lines
23 KiB
Python
653 lines
23 KiB
Python
import argparse
|
|
from pprint import pprint
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.data
|
|
from torch.distributions import Normal
|
|
from tqdm import tqdm
|
|
|
|
from dataset.shapenet_data_pc import ShapeNet15kPointClouds
|
|
from metrics.evaluation_metrics import compute_all_metrics
|
|
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
|
from model.pvcnn_generation import PVCNN2Base
|
|
from utils.file_utils import *
|
|
from utils.visualize import *
|
|
|
|
# ---------------------------------------------------------------------------
# models
# ---------------------------------------------------------------------------
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ).

    Both Gaussians are given by a mean and a *log*-variance; arguments are torch
    tensors combined with ordinary broadcasting.
    """
    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_mean_gap = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_mean_gap)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-probability of ``x`` under a Gaussian discretized into unit-width bins.

    Intended for data taking the values 0 and 1: the bin of each value spans
    ``x +/- 0.5``. Values below 0.001 use the whole lower tail ``CDF(x + 0.5)``,
    values above 0.999 the whole upper tail ``1 - CDF(x - 0.5)``, everything
    else the bin mass ``CDF(x + 0.5) - CDF(x - 0.5)``. All probabilities are
    floored at 1e-12 before taking the log.

    Args:
        x: data tensor.
        means: Gaussian means, same shape as ``x``.
        log_scales: log standard deviations, same shape as ``x``.

    Returns:
        Tensor of log-probabilities with the shape of ``x``.
    """
    assert x.shape == means.shape == log_scales.shape
    standard_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    def floored_log(p):
        # Avoid log(0) for bins with negligible mass.
        return torch.log(torch.max(p, torch.ones_like(p) * 1e-12))

    offset = x - means
    inv_std = torch.exp(-log_scales)
    upper_cdf = standard_normal.cdf(inv_std * (offset + 0.5))
    lower_cdf = standard_normal.cdf(inv_std * (offset - 0.5))

    log_lower_tail = floored_log(upper_cdf)
    log_upper_tail = floored_log(1.0 - lower_cdf)
    log_bin_mass = floored_log(upper_cdf - lower_cdf)

    not_low = torch.where(x > 0.999, log_upper_tail, log_bin_mass)
    log_probs = torch.where(x < 0.001, log_lower_tail, not_low)
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    """Fixed-schedule Gaussian diffusion process used for sampling (and mse training loss).

    Schedule constants are derived in float64 NumPy from ``betas`` and stored as
    float32 CPU tensors; every method moves the constants it uses onto the
    device of its input.

    Args:
        betas: 1-D NumPy array of per-step noise variances, each in (0, 1].
        loss_type: training-loss name; ``p_losses`` implements only ``"mse"``.
        model_mean_type: what the denoiser predicts; only ``"eps"`` is supported.
        model_var_type: ``"fixedsmall"`` or ``"fixedlarge"``.
    """

    def __init__(self, betas, loss_type, model_mean_type, model_var_type):
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type

        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)

        def to_f32(arr):
            # Convert only at the very end so every derived quantity is computed in float64.
            return torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float64)).float()

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        self.betas = to_f32(betas)
        self.alphas_cumprod = to_f32(alphas_cumprod)
        self.alphas_cumprod_prev = to_f32(alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = to_f32(np.sqrt(alphas_cumprod))
        self.sqrt_one_minus_alphas_cumprod = to_f32(np.sqrt(1.0 - alphas_cumprod))
        self.log_one_minus_alphas_cumprod = to_f32(np.log(1.0 - alphas_cumprod))
        self.sqrt_recip_alphas_cumprod = to_f32(np.sqrt(1.0 / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod = to_f32(np.sqrt(1.0 / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = to_f32(posterior_variance)
        # log is clipped because the posterior variance is 0 at the beginning of the chain
        self.posterior_log_variance_clipped = to_f32(np.log(np.maximum(posterior_variance, 1e-20)))
        self.posterior_mean_coef1 = to_f32(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.posterior_mean_coef2 = to_f32(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

    @staticmethod
    def _extract(a, t, x_shape):
        """Gather ``a[t]`` per batch element and reshape to [bs, 1, 1, ...] for broadcasting."""
        (bs,) = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """Mean, variance and log-variance of q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """Diffuse the data (t == 0 means diffused for 1 step)."""
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of the posterior q(x_{t-1} | x_t, x_0)."""
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Mean/variance of the learned reverse step p(x_{t-1} | x_t).

        Raises:
            NotImplementedError: for an unsupported ``model_var_type`` or ``model_mean_type``.
        """
        model_output = denoise_fn(data, t)

        if self.model_var_type in ["fixedsmall", "fixedlarge"]:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                "fixedlarge": (
                    self.betas.to(data.device),
                    torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
                ),
                "fixedsmall": (
                    self.posterior_variance.to(data.device),
                    self.posterior_log_variance_clipped.to(data.device),
                ),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == "eps":
            x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
            if clip_denoised:
                x_recon = torch.clamp(x_recon, -0.5, 0.5)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
        else:
            # Report the offending mean type (previously reported loss_type by mistake).
            raise NotImplementedError(self.model_mean_type)

        assert model_mean.shape == x_recon.shape == data.shape
        assert model_variance.shape == model_log_variance.shape == data.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Invert q_sample: recover x_0 from x_t and the predicted noise."""
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """Per-example training loss for ``data_start`` diffused to timesteps ``t``.

        Only ``loss_type == "mse"`` with ``model_mean_type == "eps"`` is implemented.
        The loss is the squared error between the injected noise and the denoiser
        output, averaged over all non-batch dimensions.

        Returns:
            Tensor of shape ``[batch]``.

        Raises:
            NotImplementedError: for any other loss or mean type.
        """
        if self.loss_type != "mse":
            raise NotImplementedError(self.loss_type)
        if self.model_mean_type != "eps":
            raise NotImplementedError(self.model_mean_type)
        if noise is None:
            noise = torch.randn(data_start.shape, dtype=data_start.dtype, device=data_start.device)
        assert noise.shape == data_start.shape

        data_t = self.q_sample(x_start=data_start, t=t, noise=noise)
        eps_recon = denoise_fn(data_t, t)
        assert eps_recon.shape == data_start.shape
        return ((noise - eps_recon) ** 2).mean(dim=list(range(1, data_start.dim())))

    """ samples """

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
        """Draw x_{t-1} from the model given x_t (no noise is added where t == 0)."""
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
        )
        noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
        assert noise.shape == data.shape
        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))

        sample = model_mean
        if use_var:
            sample = sample + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        assert sample.shape == pred_xstart.shape
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(
        self,
        denoise_fn,
        shape,
        device,
        noise_fn=torch.randn,
        constrain_fn=lambda x, t: x,
        clip_denoised=True,
        max_timestep=None,
        keep_running=False,
    ):
        """Generate samples by running the reverse chain from pure noise.

        Args:
            shape: tuple or list giving the output shape.
            max_timestep: number of reverse steps to run (default: ``num_timesteps``).
            keep_running: if True, run ``len(self.betas)`` steps regardless of
                ``max_timestep``. With the current schedule this equals ``num_timesteps``.
            constrain_fn: applied to the sample before every step as ``constrain_fn(x, t)``.
        """
        final_time = self.num_timesteps if max_timestep is None else max_timestep

        assert isinstance(shape, (tuple, list))
        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        num_steps = len(self.betas) if keep_running else final_time
        for t in reversed(range(0, num_steps)):
            img_t = constrain_fn(img_t, t)
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=clip_denoised,
                return_pred_xstart=False,
            ).detach()

        # Compare as tuples so a list ``shape`` is accepted too (torch.Size != list).
        assert tuple(img_t.shape) == tuple(shape)
        return img_t

    def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t: x):
        """Diffuse ``x0`` forward to step ``t - 1`` and denoise it back over ``t`` reverse steps."""
        assert t >= 1

        t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t - 1)
        img_t = self.q_sample(x0, t_vec)

        for k in reversed(range(0, t)):
            img_t = constrain_fn(img_t, k)
            t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=False,
                return_pred_xstart=False,
                use_var=True,
            ).detach()

        return img_t
class PVCNN2(PVCNN2Base):
    """PVCNN2 backbone with this script's fixed block layout.

    ``sa_blocks`` / ``fp_blocks`` are class-level configuration tables; they are
    presumably consumed by ``PVCNN2Base`` (defined in ``model.pvcnn_generation``)
    when building the network — confirm there for the meaning of each tuple field.
    """

    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(
        self,
        num_classes,
        embed_dim,
        use_att,
        dropout,
        extra_feature_channels=3,
        width_multiplier=1,
        voxel_resolution_multiplier=1,
    ):
        # Forward every constructor argument unchanged to the base class.
        base_kwargs = dict(
            num_classes=num_classes,
            embed_dim=embed_dim,
            use_att=use_att,
            dropout=dropout,
            extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier,
            voxel_resolution_multiplier=voxel_resolution_multiplier,
        )
        super().__init__(**base_kwargs)
class Model(nn.Module):
    """PVCNN2 noise-prediction network paired with a ``GaussianDiffusion`` process.

    ``self.model`` (the PVCNN2 network) is the only registered sub-module and so
    holds every parameter; ``self.diffusion`` is a plain (non-Module) object with
    the schedule constants and sampling routines.
    """

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)

        self.model = PVCNN2(
            num_classes=args.nc,
            embed_dim=args.embed_dim,
            use_att=args.attention,
            dropout=args.dropout,
            extra_feature_channels=0,
        )

    def prior_kl(self, x0):
        # NOTE(review): the GaussianDiffusion class in this file defines no `_prior_bpd`;
        # this raises AttributeError unless that method is supplied elsewhere.
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        # NOTE(review): `calc_bpd_loop` is likewise not defined on the GaussianDiffusion in this file.
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}

    def _denoise(self, data, t):
        """Run the network on float ``data`` of shape [B, D, N] at int64 timesteps ``t`` of shape [B]."""
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t)

        assert out.shape == torch.Size([B, D, N])
        return out

    def get_loss_iter(self, data, noises=None):
        """Per-example training loss at uniformly random timesteps.

        Delegates to ``self.diffusion.p_losses``. When ``noises`` is given, its
        rows whose sampled timestep is non-zero are overwritten IN PLACE with
        fresh Gaussian noise; rows with t == 0 keep the caller's noise.
        """
        B, D, N = data.shape
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(
        self,
        shape,
        device,
        noise_fn=torch.randn,
        constrain_fn=lambda x, t: x,
        clip_denoised=False,
        max_timestep=None,
        keep_running=False,
    ):
        """Sample new point clouds of ``shape`` via the diffusion reverse chain."""
        return self.diffusion.p_sample_loop(
            self._denoise,
            shape=shape,
            device=device,
            noise_fn=noise_fn,
            constrain_fn=constrain_fn,
            clip_denoised=clip_denoised,
            max_timestep=max_timestep,
            keep_running=keep_running,
        )

    def reconstruct(self, x0, t, constrain_fn=lambda x, t: x):
        """Noise ``x0`` to step ``t`` and denoise it back."""
        return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)

    def train(self, mode: bool = True):
        """Set train (``mode=True``) or eval mode, matching ``nn.Module.train``.

        Accepts the standard ``mode`` argument and returns ``self`` so that code
        calling ``model.train(False)`` or chaining ``model = model.train()`` works.
        """
        self.model.train(mode)
        self.training = mode
        return self

    def eval(self):
        """Switch to evaluation mode; equivalent to ``self.train(False)``. Returns ``self``."""
        return self.train(False)

    def multi_gpu_wrapper(self, f):
        """Replace the network with ``f(network)``, e.g. an ``nn.DataParallel`` wrapper."""
        self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
    """Return a float64 NumPy array of ``time_num`` diffusion betas.

    Supported ``schedule_type`` values:
      * ``"linear"``: evenly spaced from ``b_start`` to ``b_end``.
      * ``"warm<frac>"`` (e.g. ``"warm0.1"``, ``"warm0.2"``, ``"warm0.5"``):
        a linear ramp from ``b_start`` to ``b_end`` over the first
        ``int(time_num * frac)`` steps, then constant ``b_end``.
        ``frac`` must lie in [0, 1].

    Numeric arguments may also be given as strings (e.g. straight from a CLI).

    Raises:
        NotImplementedError: for any other schedule name.
    """
    b_start = float(b_start)
    b_end = float(b_end)
    time_num = int(time_num)

    if schedule_type == "linear":
        return np.linspace(b_start, b_end, time_num)

    if isinstance(schedule_type, str) and schedule_type.startswith("warm"):
        try:
            warmup_frac = float(schedule_type[len("warm"):])
        except ValueError:
            raise NotImplementedError(schedule_type) from None
        if not 0.0 <= warmup_frac <= 1.0:
            raise NotImplementedError(schedule_type)
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * warmup_frac)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
        return betas

    raise NotImplementedError(schedule_type)
def get_constrain_function(ground_truth, mask, eps, num_steps=1, num_timesteps=1000):
    """Build a ``constrain_fn(x, t)`` that pulls masked entries of ``x`` toward ``ground_truth``.

    Each call applies ``num_steps`` updates ``x <- x - eps_t * (x - ground_truth) * mask``.
    The step size ``eps_t`` follows a quadratic ramp: ``eps`` at ``t == 0`` down
    to 0 at ``t == num_timesteps - 1``. For ``t >= num_timesteps`` no pull is applied.

    :param ground_truth: target values, broadcastable against ``x``
    :param mask: per-entry weight of the pull (e.g. 1 = constrained, 0 = free)
    :param eps: step size used at ``t == 0``
    :param num_steps: number of update iterations per call
    :param num_timesteps: length of the step-size schedule (default 1000)
    :return: ``constrain_fn(x, t)`` returning the constrained ``x``
    """
    eps_all = list(reversed(np.linspace(0, np.sqrt(eps), num_timesteps) ** 2))

    def constrain_fn(x, t):
        eps_ = eps_all[t] if (t < num_timesteps) else 0
        for _ in range(num_steps):
            x = x - eps_ * ((x - ground_truth) * mask)
        return x

    return constrain_fn
#############################################################################
def get_dataset(dataroot, npoints, category, use_mask=False):
    """Build the (train, val) ShapeNet15k point-cloud datasets for one category.

    The validation set reuses the training set's ``all_points_mean`` /
    ``all_points_std`` so both splits share one normalization.
    """
    train_set = ShapeNet15kPointClouds(
        root_dir=dataroot,
        categories=[category],
        split="train",
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.0,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
        use_mask=use_mask,
    )

    val_set = ShapeNet15kPointClouds(
        root_dir=dataroot,
        categories=[category],
        split="val",
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.0,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        all_points_mean=train_set.all_points_mean,
        all_points_std=train_set.all_points_std,
        use_mask=use_mask,
    )

    return train_set, val_set
def _load_reference_point_clouds(opt):
    """Concatenate the de-normalized ``test_points`` of the validation split into one tensor."""
    _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category, use_mask=False)
    loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        drop_last=False,
    )
    batches = [
        batch["test_points"] * batch["std"].float() + batch["mean"].float()
        for batch in tqdm(loader, total=len(loader), desc="Generating Samples")
    ]
    return torch.cat(batches, dim=0).contiguous()


def evaluate_gen(opt, ref_pcs, logger):
    """Score the samples saved at ``opt.eval_path`` against reference point clouds.

    If ``ref_pcs`` is None the references are rebuilt from the validation split.
    Prints and logs the ``compute_all_metrics`` results and the JSD score.
    """
    if ref_pcs is None:
        ref_pcs = _load_reference_point_clouds(opt)

    logger.info("Loading sample path: %s" % (opt.eval_path))
    sample_pcs = torch.load(opt.eval_path).contiguous()
    logger.info("Generation sample size:%s reference size: %s" % (sample_pcs.size(), ref_pcs.size()))

    # Compute metrics; tensor-valued entries are converted to Python scalars.
    raw_metrics = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    results = {}
    for name, value in raw_metrics.items():
        results[name] = value if isinstance(value, float) else value.cpu().detach().item()

    pprint(results)
    logger.info(results)

    jsd_message = "JSD: {}".format(JSD(sample_pcs.numpy(), ref_pcs.numpy()))
    pprint(jsd_message)
    logger.info(jsd_message)
def _show_point_clouds(point_clouds):
    """Open one blocking matplotlib 3-D scatter window per point cloud (debug aid).

    Each element is indexed as ``pc[:, 0]``, ``pc[:, 1]``, ``pc[:, 2]`` (x, y, z);
    points are coloured by their third coordinate.
    """
    # Imported lazily so runs that never visualize need no GUI backend.
    import matplotlib
    import matplotlib.cm as cm
    import matplotlib.pyplot as plt

    matplotlib.use("TkAgg")
    for idx, pc in enumerate(point_clouds):
        print(f"Visualizing point cloud {idx}...")
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet)
        ax.set_aspect("equal")
        ax.axis("off")
        plt.show()
        plt.close()


def generate(model, opt):
    """Sample one generated point cloud per validation shape and save them.

    Generated samples and the matching references are de-normalized with each
    shape's ``mean``/``std``. The samples are saved with ``torch.save`` to
    ``opt.eval_path``, and the references are returned as one tensor.

    The interactive matplotlib preview used to run unconditionally and blocked
    generation. It now runs only when ``opt.visualize`` is truthy; that attribute
    is optional and defaults to False.
    """
    _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
    )

    # Sample on the device holding the model's parameters (previously hard-coded "cuda").
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    show_samples = bool(getattr(opt, "visualize", False))

    with torch.no_grad():
        samples = []
        ref = []

        for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
            x = data["test_points"].transpose(1, 2)
            m, s = data["mean"].float(), data["std"].float()

            gen = model.gen_samples(x.shape, device, clip_denoised=False).detach().cpu()

            gen = gen.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            gen = gen * s + m
            x = x * s + m
            samples.append(gen)
            ref.append(x)

            if show_samples:
                _show_point_clouds(gen[:64])

        samples = torch.cat(samples, dim=0)
        ref = torch.cat(ref, dim=0)

        torch.save(samples, opt.eval_path)

    return ref
def main(opt):
    """Load a trained checkpoint, optionally generate samples, and optionally evaluate them.

    For ``opt.category == "airplane"`` the beta schedule options are overridden
    with fixed values. CUDA is used only when ``opt.cuda`` is set.
    """
    if opt.category == "airplane":
        opt.beta_start = 1e-5
        opt.beta_end = 0.008
        opt.schedule_type = "warm0.1"

    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    (outf_syn,) = setup_output_subdirs(output_dir, "syn")

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

    # Move to GPU only when one is available (previously done twice, and unconditionally).
    if opt.cuda:
        model = model.cuda()

    # Wrap before loading so parameter names match the wrapped layout expected by the
    # checkpoint (presumably saved from a DataParallel-wrapped model). Without GPUs,
    # DataParallel simply forwards to the wrapped module.
    model.multi_gpu_wrapper(nn.parallel.DataParallel)

    model.eval()

    with torch.no_grad():
        logger.info("Resume Path:%s" % opt.model)

        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        resumed_param = torch.load(opt.model, map_location=None if opt.cuda else "cpu")
        model.load_state_dict(resumed_param["model_state"])

        ref = None
        if opt.generate:
            opt.eval_path = os.path.join(outf_syn, "samples.pth")
            Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
            ref = generate(model, opt)

        if opt.eval_gen:
            # Evaluate generation
            evaluate_gen(opt, ref, logger)
def parse_args():
    """Parse command-line options for sampling/evaluating a trained point-cloud diffusion model.

    Every option carries an explicit ``type`` so command-line values arrive with
    the same Python type as the defaults. Previously ``--generate False`` gave
    the truthy string ``"False"``, and ``--time_num 500`` gave a string that broke
    the beta schedule. Also sets ``opt.cuda`` from ``torch.cuda.is_available()``.
    """

    def str2bool(value):
        """Interpret common textual spellings of a boolean."""
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in {"1", "true", "t", "yes", "y", "on"}:
            return True
        if text in {"0", "false", "f", "no", "n", "off"}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataroot", default="ShapeNetCore.v2.PC15k/")
    parser.add_argument("--category", default="chair")

    parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
    parser.add_argument("--workers", type=int, default=16, help="workers")
    parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")

    parser.add_argument("--generate", type=str2bool, default=True)
    parser.add_argument("--eval_gen", type=str2bool, default=True)

    parser.add_argument("--nc", type=int, default=3)
    parser.add_argument("--npoints", type=int, default=2048)

    # model / diffusion schedule
    parser.add_argument("--beta_start", type=float, default=0.0001)
    parser.add_argument("--beta_end", type=float, default=0.02)
    parser.add_argument("--schedule_type", default="linear")
    parser.add_argument("--time_num", type=int, default=1000)

    # params
    parser.add_argument("--attention", type=str2bool, default=True)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--loss_type", default="mse")
    parser.add_argument("--model_mean_type", default="eps")
    parser.add_argument("--model_var_type", default="fixedsmall")

    parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")

    # eval
    parser.add_argument("--eval_path", default="")

    parser.add_argument("--manualSeed", default=42, type=int, help="random seed")

    parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")

    opt = parser.parse_args()
    opt.cuda = bool(torch.cuda.is_available())
    return opt
if __name__ == "__main__":
    # Script entry point: parse CLI options, seed RNGs via the utils helper, then run.
    cli_options = parse_args()
    set_seed(cli_options)
    main(cli_options)