2023-04-11 09:12:58 +00:00
|
|
|
import argparse
|
|
|
|
|
2023-04-11 14:01:07 +00:00
|
|
|
import datasets
|
2023-04-11 09:12:58 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch.distributed as dist
|
2021-10-19 20:54:46 +00:00
|
|
|
import torch.multiprocessing as mp
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.optim as optim
|
|
|
|
import torch.utils.data
|
|
|
|
from torch.distributions import Normal
|
|
|
|
|
2023-04-11 14:01:07 +00:00
|
|
|
# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
|
2023-04-11 09:12:58 +00:00
|
|
|
from model.pvcnn_generation import PVCNN2Base
|
2021-10-19 20:54:46 +00:00
|
|
|
from utils.file_utils import *
|
|
|
|
from utils.visualize import *
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
"""
|
2021-10-19 20:54:46 +00:00
|
|
|
some utils
|
2023-04-11 09:12:58 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
def rotation_matrix(axis, theta):
    """Build the 3x3 matrix of a counterclockwise rotation by ``theta`` radians about ``axis``.

    Uses the Euler-Rodrigues formula. ``axis`` need not be unit length; it is
    normalized first (a zero-length axis produces NaNs).
    """
    unit = np.asarray(axis)
    unit = unit / np.sqrt(np.dot(unit, unit))
    half_angle = theta / 2.0
    w = np.cos(half_angle)
    x, y, z = -unit * np.sin(half_angle)

    ww, xx, yy, zz = w * w, x * x, y * y, z * z
    xy, wz, wy, wx, xz, yz = x * y, w * z, w * y, w * x, x * z, y * z

    row0 = [ww + xx - yy - zz, 2 * (xy + wz), 2 * (xz - wy)]
    row1 = [2 * (xy - wz), ww + yy - xx - zz, 2 * (yz + wx)]
    row2 = [2 * (xz + wy), 2 * (yz - wx), ww + zz - xx - yy]
    return np.array([row0, row1, row2])
|
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
def rotate(vertices, faces):
    """Reorient a mesh: permute columns to [1, 2, 0], then apply three fixed rotations.

    vertices: [numpoints, 3]
    faces: index array; its columns are permuted with the same [1, 2, 0] order
        (assumed to be an [n_faces, 3] triangle array — TODO confirm).

    Returns:
        (v, f): the transformed vertices and the column-permuted faces.
    """
    column_order = [1, 2, 0]
    about_y = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
    about_x = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
    about_z = rotation_matrix([0, 0, 1], np.pi).transpose()

    # Right-multiply in the order y, x, z (row-vector convention).
    v = vertices[:, column_order]
    for rot in (about_y, about_x, about_z):
        v = v.dot(rot)
    f = faces[:, column_order]
    return v, f
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
def norm(v, f):
    """Affinely rescale ``v`` so its global min maps to -0.5 and global max to 0.5.

    The same min/max is used for every coordinate, so relative proportions are
    kept. ``f`` is passed through unchanged. A constant ``v`` divides by zero.
    """
    lowest = v.min()
    extent = v.max() - lowest
    scaled = (v - lowest) / extent - 0.5
    return scaled, f
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
def getGradNorm(net):
    """Return the global L2 norms of ``net``'s parameters and of their gradients.

    Parameters whose ``.grad`` is ``None`` (no backward pass yet, or unused in the
    last forward) are skipped for the gradient norm instead of raising. Both norms
    are computed without autograd tracking, since they are diagnostics only.

    Returns:
        (pNorm, gradNorm): 0-dim tensors (zero when there is nothing to sum).
    """
    params = list(net.parameters())
    with torch.no_grad():
        # Explicit tensor start value: sum() of an empty generator would give the
        # int 0, which torch.sqrt rejects.
        zero = params[0].new_zeros(()) if params else torch.zeros(())
        pNorm = torch.sqrt(sum((torch.sum(p**2) for p in params), zero))
        gradNorm = torch.sqrt(sum((torch.sum(p.grad**2) for p in params if p.grad is not None), zero))
    return pNorm, gradNorm
|
|
|
|
|
|
|
|
|
|
|
|
def weights_init(m):
    """Initialize a single module's parameters in place (use with ``net.apply``).

    - Class name contains "Conv" and the module has a weight: Xavier-normal init
      of the weight (bias left untouched).
    - Class name contains "BatchNorm": weight ~ N(0, 1), bias = 0. Either is
      skipped when absent (e.g. ``affine=False``), which previously raised.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname and weight is not None:
        torch.nn.init.xavier_normal_(weight)
    elif "BatchNorm" in classname:
        # NOTE(review): N(0, 1) for the BatchNorm scale is unusual (DCGAN-style
        # init uses N(1, 0.02)); kept as-is to preserve training behavior.
        if weight is not None:
            weight.data.normal_()
        bias = getattr(m, "bias", None)
        if bias is not None:
            bias.data.fill_(0)
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
|
|
|
|
"""
|
2021-10-19 20:54:46 +00:00
|
|
|
models
|
2023-04-11 09:12:58 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ).

    Inputs are broadcast together; variances are given as log-variances.
    """
    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_mean_gap = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_mean_gap)
|
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """Log-probability of ``x`` under a Gaussian integrated over unit-width bins.

    Each bin spans [x - 0.5, x + 0.5] in data units, scaled by exp(-log_scales).
    Values below 0.001 use the open left tail (CDF up to x + 0.5). Values above
    0.999 use the open right tail (1 - CDF at x - 0.5). Everything else uses the
    bin mass. Probabilities are floored at 1e-12 before taking logs.

    Original note: "Assumes data is integers [0, 1]" — presumably x takes values
    in {0, 1}; TODO confirm against callers.
    """
    assert x.shape == means.shape == log_scales.shape
    std_normal = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    offset = x - means
    inv_scale = torch.exp(-log_scales)
    cdf_upper = std_normal.cdf(inv_scale * (offset + 0.5))
    cdf_lower = std_normal.cdf(inv_scale * (offset - 0.5))

    floor = 1e-12
    log_left_tail = torch.log(torch.max(cdf_upper, torch.ones_like(cdf_upper) * floor))
    log_right_tail = torch.log(torch.max(1.0 - cdf_lower, torch.ones_like(cdf_lower) * floor))
    bin_mass = cdf_upper - cdf_lower
    log_bin_mass = torch.log(torch.max(bin_mass, torch.ones_like(bin_mass) * floor))

    upper_or_interior = torch.where(x > 0.999, log_right_tail, log_bin_mass)
    log_probs = torch.where(x < 0.001, log_left_tail, upper_or_interior)
    assert log_probs.shape == x.shape
    return log_probs
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
class GaussianDiffusion:
    """Gaussian diffusion process with a fixed variance schedule (DDPM-style).

    Precomputes the per-timestep schedule coefficients from ``betas``. It exposes
    the forward process (``q_*``), the learned reverse process (``p_*``), the
    training loss (``p_losses``) and variational-bound diagnostics
    (``calc_bpd_loop``).

    Args:
        betas: 1-D ``np.ndarray`` of per-step noise variances, each in (0, 1].
        loss_type: ``"mse"`` (noise-prediction MSE) or ``"kl"`` (variational-bound term).
        model_mean_type: interpretation of the network output; only ``"eps"`` is implemented.
        model_var_type: ``"fixedsmall"`` (posterior variance) or ``"fixedlarge"`` (betas).
    """

    def __init__(self, betas, loss_type, model_mean_type, model_var_type):
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type

        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)
        assert (betas > 0).all() and (betas <= 1).all()
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)

        # NOTE: extending the schedule for ``keep_running`` sampling is disabled, so
        # len(self.betas) == num_timesteps and ``keep_running`` has no effect.

        alphas = 1.0 - betas
        # The cumulative product is taken in float64 but cast to float32 here, so all
        # coefficients derived from it below are computed in float32.
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()
        self.alphas_cumprod_prev = alphas_cumprod_prev.float()

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # log is clipped because the posterior variance is 0 at the start of the chain
        self.posterior_log_variance_clipped = torch.log(
            torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
        )
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

    @staticmethod
    def _extract(a, t, x_shape):
        """Gather ``a[t]`` per batch element and reshape to [B, 1, 1, ...] for broadcasting."""
        (bs,) = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """Mean, variance and log-variance of q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """Diffuse the data (t == 0 means diffused for 1 step).

        Draws standard-normal ``noise`` when none is given.
        """
        if noise is None:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)."""
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        """Mean and (log-)variance of the reverse step p(x_{t-1} | x_t).

        Args:
            denoise_fn: callable ``(data, t) -> eps`` prediction with ``data``'s shape.
            clip_denoised: clamp the predicted x_0 to [-0.5, 0.5] (presumably the data
                range — TODO confirm).
            return_pred_xstart: also return the predicted x_0.

        Raises:
            NotImplementedError: unsupported ``model_var_type`` or ``model_mean_type``.
        """
        model_output = denoise_fn(data, t)

        if self.model_var_type in ["fixedsmall", "fixedlarge"]:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                "fixedlarge": (
                    self.betas.to(data.device),
                    torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
                ),
                "fixedsmall": (
                    self.posterior_variance.to(data.device),
                    self.posterior_log_variance_clipped.to(data.device),
                ),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == "eps":
            x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
            if clip_denoised:
                x_recon = torch.clamp(x_recon, -0.5, 0.5)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
        else:
            # Report the offending setting (this previously reported loss_type).
            raise NotImplementedError(self.model_mean_type)

        assert model_mean.shape == x_recon.shape == data.shape
        assert model_variance.shape == model_log_variance.shape == data.shape
        if return_pred_xstart:
            return model_mean, model_variance, model_log_variance, x_recon
        return model_mean, model_variance, model_log_variance

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Invert q_sample: recover x_0 from x_t and the predicted noise ``eps``."""
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    # ---- samples ----

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """Sample x_{t-1} from the model given x_t (no noise is added where t == 0)."""
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
        )
        noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
        assert noise.shape == data.shape
        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))
        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        assert sample.shape == pred_xstart.shape
        return (sample, pred_xstart) if return_pred_xstart else sample

    def p_sample_loop(self, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """Generate samples by running the reverse chain from pure noise.

        Args:
            shape: tuple or list giving the sample shape; ``shape[0]`` is the batch size.
            keep_running: run ``len(self.betas)`` steps instead of ``num_timesteps``
                (currently identical, see ``__init__``).

        Returns:
            The final denoised sample tensor of the requested shape.
        """
        assert isinstance(shape, (tuple, list))
        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        total_steps = self.num_timesteps if not keep_running else len(self.betas)
        for t in reversed(range(0, total_steps)):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=clip_denoised,
                return_pred_xstart=False,
            )
        # Compare as tuples: torch.Size never equals a list, so a list ``shape`` used to fail here.
        assert tuple(img_t.shape) == tuple(shape)
        return img_t

    def p_sample_loop_trajectory(
        self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False
    ):
        """Generate samples, also returning intermediate states of the reverse chain.

        Useful for visualizing how denoised samples evolve over time.

        Args:
            shape: tuple or list giving the sample shape; ``shape[0]`` is the batch size.
            freq: keep the state after every step whose index ``t`` satisfies
                ``t % freq == 0``, plus the state after the first step.
            keep_running: run ``len(self.betas)`` steps instead of ``num_timesteps``.

        Returns:
            List of tensors: the initial noise followed by the kept intermediate states.
        """
        assert isinstance(shape, (tuple, list))
        total_steps = self.num_timesteps if not keep_running else len(self.betas)
        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0, total_steps)):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=clip_denoised,
                return_pred_xstart=False,
            )
            if t % freq == 0 or t == total_steps - 1:
                imgs.append(img_t)
        assert tuple(imgs[-1].shape) == tuple(shape)
        return imgs

    # ---- losses ----

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        """KL between the true posterior and the model's reverse step, in bits per dim, per batch element."""
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start, x_t=data_t, t=t)
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
        )
        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        kl = kl.mean(dim=list(range(1, len(data_start.shape)))) / np.log(2.0)
        return (kl, pred_xstart) if return_pred_xstart else kl

    def p_losses(self, denoise_fn, data_start, t, noise=None):
        """Training loss per batch element for a [B, D, N] batch at timesteps ``t``.

        Raises:
            NotImplementedError: unsupported ``loss_type``.
        """
        B, D, N = data_start.shape
        assert t.shape == torch.Size([B])

        if noise is None:
            noise = torch.randn(data_start.shape, dtype=data_start.dtype, device=data_start.device)
        assert noise.shape == data_start.shape and noise.dtype == data_start.dtype

        data_t = self.q_sample(x_start=data_start, t=t, noise=noise)

        if self.loss_type == "mse":
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(data_t, t)
            assert data_t.shape == data_start.shape
            assert eps_recon.shape == torch.Size([B, D, N])
            assert eps_recon.shape == data_start.shape
            losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == "kl":
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn,
                data_start=data_start,
                data_t=data_t,
                t=t,
                clip_denoised=False,
                return_pred_xstart=False,
            )
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == torch.Size([B])
        return losses

    # ---- debug ----

    def _prior_bpd(self, x_start):
        """KL(q(x_T | x_0) || N(0, I)) in bits per dim, per batch element."""
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(
                mean1=qt_mean,
                logvar1=qt_log_variance,
                mean2=torch.tensor([0.0]).to(qt_mean),
                logvar2=torch.tensor([0.0]).to(qt_log_variance),
            )
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        """Evaluate every variational-bound term for ``x_start`` (no gradients).

        Returns:
            (total_bpd, mean per-step term, prior_bpd, mean x_0-prediction MSE),
            each averaged to a scalar tensor.
        """
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            vals_bt_ = torch.zeros([B, T], device=x_start.device)
            mse_bt_ = torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):
                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn,
                    data_start=x_start,
                    data_t=self.q_sample(x_start=x_start, t=t_b),
                    t=t_b,
                    clip_denoised=clip_denoised,
                    return_pred_xstart=True,
                )
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start.shape
                new_mse_b = ((pred_xstart - x_start) ** 2).mean(dim=list(range(1, len(x_start.shape))))
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Every batch element is at timestep t, so write column t directly. The
                # former mask arithmetic (x * 0) also spread NaN/inf across the whole row.
                vals_bt_[:, t] = new_vals_b
                mse_bt_[:, t] = new_mse_b

            prior_bpd_b = self._prior_bpd(x_start)
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
            assert total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
|
|
|
|
|
|
|
|
|
|
|
|
class PVCNN2(PVCNN2Base):
    """PVCNN2 backbone with this file's fixed encoder/decoder layout.

    ``sa_blocks`` and ``fp_blocks`` are class attributes, presumably consumed by
    ``PVCNN2Base`` (not visible here); the meaning of each tuple field is defined
    there. All constructor arguments are forwarded unchanged.
    """

    # Set-abstraction (encoder) stage configuration.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    # Feature-propagation (decoder) stage configuration.
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(
        self,
        num_classes,
        embed_dim,
        use_att,
        dropout,
        extra_feature_channels=3,
        width_multiplier=1,
        voxel_resolution_multiplier=1,
    ):
        base_kwargs = dict(
            num_classes=num_classes,
            embed_dim=embed_dim,
            use_att=use_att,
            dropout=dropout,
            extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier,
            voxel_resolution_multiplier=voxel_resolution_multiplier,
        )
        super(PVCNN2, self).__init__(**base_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Point-cloud diffusion model: a PVCNN2 noise predictor driven by a GaussianDiffusion.

    Args:
        args: options namespace; reads ``nc``, ``embed_dim``, ``attention`` and ``dropout``.
        betas: noise schedule forwarded to ``GaussianDiffusion``.
        loss_type, model_mean_type, model_var_type: forwarded to ``GaussianDiffusion``.
    """

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)
        self.model = PVCNN2(
            num_classes=args.nc,
            embed_dim=args.embed_dim,
            use_att=args.attention,
            dropout=args.dropout,
            extra_feature_channels=0,
        )

    def prior_kl(self, x0):
        """Prior term of the variational bound in bits per dim, one value per batch element."""
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        """Full variational-bound breakdown for ``x0`` as a dict of batch-averaged scalars."""
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
        return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}

    def _denoise(self, data, t):
        """Run the network on a float [B, D, N] batch with int64 timesteps [B]; output has the same shape."""
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64
        out = self.model(data, t)
        assert out.shape == torch.Size([B, D, N])
        return out

    def get_loss_iter(self, data, noises=None):
        """Per-sample training loss for a [B, D, N] batch at uniformly random timesteps.

        If ``noises`` is given, its rows are kept only where the sampled timestep is 0;
        every other row is replaced by fresh Gaussian noise. The caller's tensor is not
        modified (this previously overwrote it in place).
        """
        B, _, _ = data.shape  # also enforces a 3-D (B, D, N) input
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            noises = noises.clone()
            resample = t != 0
            noises[resample] = torch.randn(int(resample.sum()), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """Sample a batch of the given shape by running the full reverse chain."""
        return self.diffusion.p_sample_loop(
            self._denoise,
            shape=shape,
            device=device,
            noise_fn=noise_fn,
            clip_denoised=clip_denoised,
            keep_running=keep_running,
        )

    def gen_sample_traj(self, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        """Like ``gen_samples`` but returns the list of intermediate states (see ``freq``)."""
        return self.diffusion.p_sample_loop_trajectory(
            self._denoise,
            shape=shape,
            device=device,
            noise_fn=noise_fn,
            freq=freq,
            clip_denoised=clip_denoised,
            keep_running=keep_running,
        )

    def train(self, mode: bool = True):
        """Set training mode on this wrapper and the wrapped network; returns ``self``.

        Keeps ``nn.Module.train``'s signature. The previous zero-argument override broke
        ``train(False)`` and left ``self.training`` stale.
        """
        return super().train(mode)

    def eval(self):
        """Switch to evaluation mode; returns ``self``."""
        return self.train(False)

    def multi_gpu_wrapper(self, f):
        """Replace the wrapped network with ``f(network)``, e.g. a (Distributed)DataParallel wrapper."""
        self.model = f(self.model)
|
|
|
|
|
|
|
|
|
|
|
|
def get_betas(schedule_type, b_start, b_end, time_num):
    """Build a 1-D numpy noise schedule of length ``time_num``.

    Args:
        schedule_type: ``"linear"`` for a linear ramp from ``b_start`` to ``b_end``, or
            ``"warm0.1"`` / ``"warm0.2"`` / ``"warm0.5"`` for a linear ramp over that
            fraction of the steps followed by a constant ``b_end``.

    Raises:
        NotImplementedError: unknown ``schedule_type``.
    """
    if schedule_type == "linear":
        return np.linspace(b_start, b_end, time_num)

    # The warm-up variants differed only in this fraction; one table replaces three copies.
    warmup_fractions = {"warm0.1": 0.1, "warm0.2": 0.2, "warm0.5": 0.5}
    if schedule_type not in warmup_fractions:
        raise NotImplementedError(schedule_type)

    betas = b_end * np.ones(time_num, dtype=np.float64)
    warmup_time = int(time_num * warmup_fractions[schedule_type])
    betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    return betas
|
|
|
|
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
def get_dataset(dataroot, npoints, category, script_path="dataset/rotor37_data.py"):
    """Load the train and test splits as torch-formatted Hugging Face datasets.

    Args:
        dataroot, npoints, category: kept for call-site compatibility with the former
            ShapeNet15kPointClouds loader; currently ignored.
        script_path: ``datasets`` loading script to read (defaults to the Rotor37 script).

    Returns:
        (train_ds, test_ds)
    """
    train_ds, test_ds = (
        datasets.load_dataset(script_path, split=split).with_format("torch") for split in ("train", "test")
    )
    return train_ds, test_ds
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_dataloader(opt, train_dataset, test_dataset=None):
    """Build train/test DataLoaders, with DistributedSamplers in multi-process mode.

    Args:
        opt: options namespace; reads ``distribution_type``, ``bs``, ``workers`` and,
            when ``distribution_type == "multi"``, ``world_size`` and ``rank``.
        train_dataset: dataset for training (shuffled, last partial batch dropped).
        test_dataset: optional evaluation dataset (not shuffled, all batches kept).

    Returns:
        (train_dataloader, test_dataloader, train_sampler, test_sampler). Test entries
        are ``None`` without ``test_dataset``; samplers are ``None`` unless "multi".
    """
    train_sampler = None
    test_sampler = None
    if opt.distribution_type == "multi":
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=opt.world_size, rank=opt.rank
        )
        if test_dataset is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas=opt.world_size, rank=opt.rank
            )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.bs,
        sampler=train_sampler,
        shuffle=train_sampler is None,
        num_workers=int(opt.workers),
        drop_last=True,
    )

    test_dataloader = None
    if test_dataset is not None:
        # Must wrap test_dataset: this previously wrapped train_dataset, so "test"
        # batches were training data (and mismatched test_sampler's indices).
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.bs,
            sampler=test_sampler,
            shuffle=False,
            num_workers=int(opt.workers),
            drop_last=False,
        )

    return train_dataloader, test_dataloader, train_sampler, test_sampler
|
|
|
|
|
|
|
|
|
2023-04-11 14:01:07 +00:00
|
|
|
def train(gpu, opt, output_dir):
|
2021-10-19 20:54:46 +00:00
|
|
|
set_seed(opt)
|
|
|
|
logger = setup_logging(output_dir)
|
2023-04-11 09:12:58 +00:00
|
|
|
if opt.distribution_type == "multi":
|
|
|
|
should_diag = gpu == 0
|
2021-10-19 20:54:46 +00:00
|
|
|
else:
|
|
|
|
should_diag = True
|
|
|
|
if should_diag:
|
2023-04-11 09:12:58 +00:00
|
|
|
(outf_syn,) = setup_output_subdirs(output_dir, "syn")
|
2021-10-19 20:54:46 +00:00
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
if opt.distribution_type == "multi":
|
2021-10-19 20:54:46 +00:00
|
|
|
if opt.dist_url == "env://" and opt.rank == -1:
|
|
|
|
opt.rank = int(os.environ["RANK"])
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
base_rank = opt.rank * opt.ngpus_per_node
|
2021-10-19 20:54:46 +00:00
|
|
|
opt.rank = base_rank + gpu
|
2023-04-11 09:12:58 +00:00
|
|
|
dist.init_process_group(
|
|
|
|
backend=opt.dist_backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank
|
|
|
|
)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
opt.bs = int(opt.bs / opt.ngpus_per_node)
|
|
|
|
opt.workers = 0
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
|
2021-10-19 20:54:46 +00:00
|
|
|
opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
|
|
|
|
opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
""" data """
|
2021-10-19 20:54:46 +00:00
|
|
|
train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
|
|
|
|
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
"""
|
2021-10-19 20:54:46 +00:00
|
|
|
create networks
|
2023-04-11 09:12:58 +00:00
|
|
|
"""
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
|
|
|
|
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
if opt.distribution_type == "multi": # Multiple processes, single GPU per process
|
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
def _transform_(m):
|
2023-04-11 09:12:58 +00:00
|
|
|
return nn.parallel.DistributedDataParallel(m, device_ids=[gpu], output_device=gpu)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
torch.cuda.set_device(gpu)
|
|
|
|
model.cuda(gpu)
|
|
|
|
model.multi_gpu_wrapper(_transform_)
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
elif opt.distribution_type == "single":
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
def _transform_(m):
|
|
|
|
return nn.parallel.DataParallel(m)
|
2023-04-11 09:12:58 +00:00
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
model = model.cuda()
|
|
|
|
model.multi_gpu_wrapper(_transform_)
|
|
|
|
|
|
|
|
elif gpu is not None:
|
|
|
|
torch.cuda.set_device(gpu)
|
|
|
|
model = model.cuda(gpu)
|
|
|
|
else:
|
2023-04-11 09:12:58 +00:00
|
|
|
raise ValueError("distribution_type = multi | single | None")
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
if should_diag:
|
|
|
|
logger.info(opt)
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999))
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.lr_gamma)
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
if opt.model != "":
|
2021-10-19 20:54:46 +00:00
|
|
|
ckpt = torch.load(opt.model)
|
2023-04-11 09:12:58 +00:00
|
|
|
model.load_state_dict(ckpt["model_state"])
|
|
|
|
optimizer.load_state_dict(ckpt["optimizer_state"])
|
2021-10-19 20:54:46 +00:00
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
if opt.model != "":
|
|
|
|
start_epoch = torch.load(opt.model)["epoch"] + 1
|
2021-10-19 20:54:46 +00:00
|
|
|
else:
|
|
|
|
start_epoch = 0
|
|
|
|
|
|
|
|
def new_x_chain(x, num_chain):
|
|
|
|
return torch.randn(num_chain, *x.shape[1:], device=x.device)
|
|
|
|
|
|
|
|
for epoch in range(start_epoch, opt.niter):
|
2023-04-11 09:12:58 +00:00
|
|
|
if opt.distribution_type == "multi":
|
2021-10-19 20:54:46 +00:00
|
|
|
train_sampler.set_epoch(epoch)
|
|
|
|
|
|
|
|
lr_scheduler.step(epoch)
|
|
|
|
|
|
|
|
for i, data in enumerate(dataloader):
|
2023-04-11 14:01:07 +00:00
|
|
|
# x = data["train_points"].transpose(1, 2)
|
|
|
|
x = data["positions"].transpose(1, 2)
|
|
|
|
noises_batch = torch.randn_like(x)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
"""
|
2021-10-19 20:54:46 +00:00
|
|
|
train diffusion
|
2023-04-11 09:12:58 +00:00
|
|
|
"""
|
2021-10-19 20:54:46 +00:00
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
if opt.distribution_type == "multi" or (opt.distribution_type is None and gpu is not None):
|
2021-10-19 20:54:46 +00:00
|
|
|
x = x.cuda(gpu)
|
|
|
|
noises_batch = noises_batch.cuda(gpu)
|
2023-04-11 09:12:58 +00:00
|
|
|
elif opt.distribution_type == "single":
|
2021-10-19 20:54:46 +00:00
|
|
|
x = x.cuda()
|
|
|
|
noises_batch = noises_batch.cuda()
|
|
|
|
|
|
|
|
loss = model.get_loss_iter(x, noises_batch).mean()
|
|
|
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
loss.backward()
|
|
|
|
netpNorm, netgradNorm = getGradNorm(model)
|
|
|
|
if opt.grad_clip is not None:
|
|
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
|
|
|
|
|
|
|
|
optimizer.step()
|
|
|
|
|
|
|
|
if i % opt.print_freq == 0 and should_diag:
|
2023-04-11 09:12:58 +00:00
|
|
|
logger.info(
|
|
|
|
"[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, "
|
|
|
|
"netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ".format(
|
|
|
|
epoch,
|
|
|
|
opt.niter,
|
|
|
|
i,
|
|
|
|
len(dataloader),
|
|
|
|
loss.item(),
|
|
|
|
netpNorm,
|
|
|
|
netgradNorm,
|
|
|
|
)
|
|
|
|
)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
if (epoch + 1) % opt.diagIter == 0 and should_diag:
|
2023-04-11 09:12:58 +00:00
|
|
|
logger.info("Diagnosis:")
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
x_range = [x.min().item(), x.max().item()]
|
|
|
|
kl_stats = model.all_kl(x)
|
2023-04-11 09:12:58 +00:00
|
|
|
logger.info(
|
|
|
|
" [{:>3d}/{:>3d}] "
|
|
|
|
"x_range: [{:>10.4f}, {:>10.4f}], "
|
|
|
|
"total_bpd_b: {:>10.4f}, "
|
|
|
|
"terms_bpd: {:>10.4f}, "
|
|
|
|
"prior_bpd_b: {:>10.4f} "
|
|
|
|
"mse_bt: {:>10.4f} ".format(
|
|
|
|
epoch,
|
|
|
|
opt.niter,
|
|
|
|
*x_range,
|
|
|
|
kl_stats["total_bpd_b"].item(),
|
|
|
|
kl_stats["terms_bpd"].item(),
|
|
|
|
kl_stats["prior_bpd_b"].item(),
|
|
|
|
kl_stats["mse_bt"].item(),
|
|
|
|
)
|
|
|
|
)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
if (epoch + 1) % opt.vizIter == 0 and should_diag:
|
2023-04-11 09:12:58 +00:00
|
|
|
logger.info("Generation: eval")
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
model.eval()
|
|
|
|
with torch.no_grad():
|
|
|
|
x_gen_eval = model.gen_samples(new_x_chain(x, 25).shape, x.device, clip_denoised=False)
|
|
|
|
x_gen_list = model.gen_sample_traj(new_x_chain(x, 1).shape, x.device, freq=40, clip_denoised=False)
|
|
|
|
x_gen_all = torch.cat(x_gen_list, dim=0)
|
|
|
|
|
|
|
|
gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
|
|
|
|
gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
logger.info(
|
|
|
|
" [{:>3d}/{:>3d}] "
|
|
|
|
"eval_gen_range: [{:>10.4f}, {:>10.4f}] "
|
|
|
|
"eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ".format(
|
|
|
|
epoch,
|
|
|
|
opt.niter,
|
|
|
|
*gen_eval_range,
|
|
|
|
*gen_stats,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
visualize_pointcloud_batch(
|
|
|
|
"%s/epoch_%03d_samples_eval.png" % (outf_syn, epoch), x_gen_eval.transpose(1, 2), None, None, None
|
|
|
|
)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
visualize_pointcloud_batch(
|
|
|
|
"%s/epoch_%03d_samples_eval_all.png" % (outf_syn, epoch), x_gen_all.transpose(1, 2), None, None, None
|
|
|
|
)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
visualize_pointcloud_batch("%s/epoch_%03d_x.png" % (outf_syn, epoch), x.transpose(1, 2), None, None, None)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
logger.info("Generation: train")
|
2021-10-19 20:54:46 +00:00
|
|
|
model.train()
|
|
|
|
|
|
|
|
if (epoch + 1) % opt.saveIter == 0:
|
|
|
|
if should_diag:
|
|
|
|
save_dict = {
|
2023-04-11 09:12:58 +00:00
|
|
|
"epoch": epoch,
|
|
|
|
"model_state": model.state_dict(),
|
|
|
|
"optimizer_state": optimizer.state_dict(),
|
2021-10-19 20:54:46 +00:00
|
|
|
}
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
torch.save(save_dict, "%s/epoch_%d.pth" % (output_dir, epoch))
|
2021-10-19 20:54:46 +00:00
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
if opt.distribution_type == "multi":
|
2021-10-19 20:54:46 +00:00
|
|
|
dist.barrier()
|
2023-04-11 09:12:58 +00:00
|
|
|
map_location = {"cuda:%d" % 0: "cuda:%d" % gpu}
|
2021-10-19 20:54:46 +00:00
|
|
|
model.load_state_dict(
|
2023-04-11 09:12:58 +00:00
|
|
|
torch.load("%s/epoch_%d.pth" % (output_dir, epoch), map_location=map_location)["model_state"]
|
|
|
|
)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
dist.destroy_process_group()
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
|
2021-10-19 20:54:46 +00:00
|
|
|
def main():
    """Entry point: parse options, prepare the output directory and launch training.

    With ``--distribution_type multi`` one worker process is spawned per GPU
    reported by ``torch.cuda.device_count()``; otherwise ``train`` runs in this
    process on ``opt.gpu``.
    """
    opt = parse_args()

    # Category-specific overrides of the noise schedule.
    if opt.category == "airplane":
        opt.beta_start = 1e-5
        opt.beta_end = 0.008
        opt.schedule_type = "warm0.1"

    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
    output_dir = get_output_dir(dir_id, exp_id)
    copy_source(__file__, output_dir)

    if opt.dist_url == "env://" and opt.world_size == -1:
        opt.world_size = int(os.environ["WORLD_SIZE"])

    if opt.distribution_type == "multi":
        opt.ngpus_per_node = torch.cuda.device_count()
        # --world_size is given as a number of nodes; convert it to the total process count.
        opt.world_size = opt.ngpus_per_node * opt.world_size
        # mp.spawn prepends the process index, so each worker runs
        # train(index, opt, output_dir) -- the same signature as the branch below.
        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir))
    else:
        train(opt.gpu, opt, output_dir)
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Build the command-line parser for diffusion training and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options. Numeric options carry an explicit
        ``type`` so that values passed on the command line have the same type as
        the defaults (``train`` uses e.g. ``opt.diagIter`` in ``/`` and ``%``).
    """

    def _str2bool(value):
        # argparse's ``type=bool`` would treat any non-empty string (including
        # "False") as True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataroot", default="ShapeNetCore.v2.PC15k/")
    parser.add_argument("--category", default="chair")

    parser.add_argument("--bs", type=int, default=16, help="input batch size")
    parser.add_argument("--workers", type=int, default=16, help="workers")
    parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")

    parser.add_argument("--nc", type=int, default=3)
    parser.add_argument("--npoints", type=int, default=2048)

    # model
    parser.add_argument("--beta_start", type=float, default=0.0001)
    parser.add_argument("--beta_end", type=float, default=0.02)
    parser.add_argument("--schedule_type", default="linear")
    parser.add_argument("--time_num", type=int, default=1000)

    # params
    parser.add_argument("--attention", type=_str2bool, default=True)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--loss_type", default="mse")
    parser.add_argument("--model_mean_type", default="eps")
    parser.add_argument("--model_var_type", default="fixedsmall")

    parser.add_argument("--lr", type=float, default=2e-4, help="learning rate for E, default=0.0002")
    parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for adam. default=0.5")
    parser.add_argument("--decay", type=float, default=0, help="weight decay for EBM")
    parser.add_argument(
        "--grad_clip", type=float, default=None, help="max gradient norm for clipping (disabled when unset)"
    )
    parser.add_argument("--lr_gamma", type=float, default=0.998, help="lr decay for EBM")

    parser.add_argument("--model", default="", help="path to model (to continue training)")

    # distributed
    parser.add_argument("--world_size", default=1, type=int, help="Number of distributed nodes.")
    parser.add_argument(
        "--dist_url", default="tcp://127.0.0.1:9991", type=str, help="url used to set up distributed training"
    )
    parser.add_argument("--dist_backend", default="nccl", type=str, help="distributed backend")
    parser.add_argument(
        "--distribution_type",
        default="single",
        choices=["multi", "single", None],
        help="Use multi-processing distributed training to launch "
        "N processes per node, which has N GPUs. This is the "
        "fastest way to use PyTorch for either single node or "
        "multi node data parallel training",
    )
    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
    parser.add_argument("--gpu", default=None, type=int, help="GPU id to use. None means using all available GPUs.")

    # eval
    parser.add_argument("--saveIter", default=100, type=int, help="unit: epoch")
    parser.add_argument("--diagIter", default=50, type=int, help="unit: epoch")
    parser.add_argument("--vizIter", default=50, type=int, help="unit: epoch")
    parser.add_argument("--print_freq", default=50, type=int, help="unit: iter")

    parser.add_argument("--manualSeed", default=42, type=int, help="random seed")

    opt = parser.parse_args()

    return opt
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch training only when run as a script, never on import.
    main()
|