LION/utils/diffusion_pvd.py

564 lines
24 KiB
Python
Raw Normal View History

2023-01-23 05:14:49 +00:00
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""copied and modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_discretized.py"""
import torch
from torch.cuda.amp import autocast
import numpy as np
from utils.diffusion import make_beta_schedule
from utils import utils
from loguru import logger
class DiffusionDiscretized(object):
    """
    Discrete-time diffusion process.

    Builds the per-step schedule constants (betas, alphas, cumulative alphas,
    posterior variances and signal-to-noise ratios) and provides the methods
    that use them: training-time weight lookup, sampling from the forward
    process, and the reverse (denoising) sampling loops.
    """
def __init__(self, args, var_fun, cfg):  # alpha_bars_fun
    """Build the discrete diffusion schedule described by ``cfg.ddpm``.

    Args:
        args: Not used by this constructor.
        var_fun: Not used by this constructor.
        cfg: Config object. ``cfg.ddpm`` must provide ``num_steps``,
            ``beta_1``, ``beta_T``, ``sched_mode``, ``p2_gamma``, ``p2_k``
            and ``use_p2_weight``.
    """
    self.cfg = cfg
    ddpm_cfg = cfg.ddpm

    # Schedule hyper-parameters.
    num_steps = ddpm_cfg.num_steps
    mode = ddpm_cfg.sched_mode
    beta_start = ddpm_cfg.beta_1
    beta_end = ddpm_cfg.beta_T

    self._diffusion_steps = num_steps
    # The reverse-process std-dev is fixed to the forward betas.
    self._denoising_stddevs = 'beta'

    # P2 loss-weighting hyper-parameters.
    self.p2_gamma = ddpm_cfg.p2_gamma
    self.p2_k = ddpm_cfg.p2_k
    self.use_p2_weight = ddpm_cfg.use_p2_weight

    logger.info(
        f'[Build Discrete Diffusion object] beta_start={beta_start}, beta_end={beta_end}, mode={mode}, num_steps={num_steps}')

    schedule = make_beta_schedule(mode, beta_start, beta_end, num_steps)
    self.betas = schedule.numpy()

    (self._betas_init,
     self._alphas,
     self._alpha_bars,
     self._betas_post_init,
     self.snr) = self._generate_base_constants(
        diffusion_steps=self._diffusion_steps)
def iw_quantities_t(self, B, timestep, *args):
    """Look up the diffusion weights for caller-supplied timesteps.

    Args:
        B: Batch size; ``timestep`` is viewed as a ``[B]`` tensor.
        timestep: Zero-based step indices in ``[0, T-1]`` (used as the index
            of ``torch.gather``).
        *args: Ignored.

    Returns:
        A 6-tuple ``(timestep, weight_noise_power, weight_init, loss_weight,
        None, None)``:
        ``timestep`` is the one-based step (input + 1), shape ``[B]``;
        ``weight_noise_power`` is ``1 - alpha_bar_t``, shape ``[B, 1, 1, 1]``;
        ``weight_init`` is ``sqrt(alpha_bar_t)``, shape ``[B, 1, 1, 1]``;
        ``loss_weight`` is the P2 weight ``1 / (p2_k + snr_t) ** p2_gamma``
        of shape ``[B]`` when ``use_p2_weight`` is set, otherwise ``1.0``.
    """
    step_index = timestep.view(B)  # zero-based, [0, T-1]
    alpha_bar_t = torch.gather(self._alpha_bars, 0, step_index)

    weight_init = torch.sqrt(alpha_bar_t)[:, None, None, None]
    weight_noise_power = (1.0 - alpha_bar_t)[:, None, None, None]

    loss_weight = 1.0
    if self.use_p2_weight:
        p2_table = 1 / (self.p2_k + self.snr)**self.p2_gamma
        loss_weight = torch.gather(p2_table, 0, step_index).view(B)

    # Callers receive one-based steps, [1, T].
    return step_index + 1, weight_noise_power, weight_init, loss_weight, None, None
def iw_quantities(self, B, *args):
    """Sample ``B`` uniformly random timesteps and return their weights.

    The random draw is created on the device that holds the schedule
    constants (``self._alpha_bars``) instead of a hard-coded ``'cuda'``, so
    the method works wherever the constants live.

    Args:
        B: Number of timesteps to sample.
        *args: Ignored.

    Returns:
        A 6-tuple ``(timestep, weight_noise_power, weight_init, loss_weight,
        None, None)``:
        ``timestep`` is the sampled one-based step in ``[1, T]``, shape ``[B]``;
        ``weight_noise_power`` is ``1 - alpha_bar_t``, shape ``[B, 1, 1, 1]``;
        ``weight_init`` is ``sqrt(alpha_bar_t)``, shape ``[B, 1, 1, 1]``;
        ``loss_weight`` is the P2 weight ``1 / (p2_k + snr_t) ** p2_gamma``
        of shape ``[B]`` when ``use_p2_weight`` is set, otherwise ``1.0``.
    """
    device = self._alpha_bars.device
    rho = torch.rand(size=[B], device=device) * self._diffusion_steps
    step_index = rho.type(torch.int64)  # truncation -> [0, T-1]
    # Internal sanity check on the sampled range.
    assert(step_index.max() <= self._diffusion_steps -
           1), f'get max at {step_index.max()}'

    alpha_bar_t = torch.gather(self._alpha_bars, 0, step_index)
    weight_init = torch.sqrt(alpha_bar_t)[:, None, None, None]
    weight_noise_power = (1.0 - alpha_bar_t)[:, None, None, None]

    if self.use_p2_weight:
        p2_table = 1 / (self.p2_k + self.snr)**self.p2_gamma
        loss_weight = torch.gather(p2_table, 0, step_index).view(B)
    else:
        loss_weight = 1.0

    timestep = step_index + 1  # [1, T]
    return timestep, weight_noise_power, weight_init, loss_weight, None, None
def debug_sheduler(self):
    """Return the schedule weights for every timestep, in order.

    Debug helper with the same return layout as ``iw_quantities``, but it
    enumerates all steps instead of sampling them. The step count comes from
    ``self._diffusion_steps`` (previously a hard-coded 1000 built with the
    deprecated ``torch.range``), and the indices are created on the device
    of ``self._alpha_bars`` rather than a hard-coded CUDA device.

    Returns:
        A 6-tuple ``(timestep, weight_noise_power, weight_init, 1, None,
        None)``:
        ``timestep`` holds the one-based steps ``1..T``, shape ``[T]``;
        ``weight_noise_power`` is ``1 - alpha_bar_t``, shape ``[T, 1, 1, 1]``;
        ``weight_init`` is ``sqrt(alpha_bar_t)``, shape ``[T, 1, 1, 1]``.
    """
    step_index = torch.arange(
        self._diffusion_steps, dtype=torch.int64,
        device=self._alpha_bars.device)  # [0, T-1]

    alpha_bar_t = torch.gather(self._alpha_bars, 0, step_index)
    weight_init = torch.sqrt(alpha_bar_t)[:, None, None, None]
    weight_noise_power = (1.0 - alpha_bar_t)[:, None, None, None]

    timestep = step_index + 1  # [1, T]
    return timestep, weight_noise_power, weight_init, 1, None, None
def sample_q(self, x_init, noise, var_t, m_t):
    """Draw a sample of the forward diffusion process at time t.

    Computes ``m_t * x_init + sqrt(var_t) * noise``.

    Args:
        x_init: Clean input, 4-D; expected ``[B, ND, 1, 1]``.
        noise: Noise tensor combined with ``x_init``; it is not
            shape-checked here.
        var_t: Noise power (variance) weight, 4-D; expected ``[B, 1, 1, 1]``.
        m_t: Weight on the clean input, 4-D; expected ``[B, 1, 1, 1]``.

    Returns:
        The noised sample.
    """
    # Every weighted operand must be 4-D ...
    for operand in (x_init, var_t, m_t):
        assert(len(operand.shape) == 4)
    # ... and the input and its weight must agree on the batch dimension.
    assert(x_init.shape[0] == m_t.shape[0])

    noise_std = torch.sqrt(var_t)
    return m_t * x_init + noise_std * noise
def cross_entropy_const(self, ode_eps):
    """Return the constant cross-entropy term, which is always 0 here.

    Args:
        ode_eps: Ignored.
    """
    constant = 0
    return constant
def _generate_base_constants(self, diffusion_steps):
"""
Generates torch tensors with basic constants for all timesteps.
"""
betas_np = self.betas # self._generate_betas_from_continuous_fun(diffusion_steps)
alphas_np = 1.0 - betas_np
alphas_cumprod = alpha_bars_np = np.cumprod(alphas_np)
snr = 1.0 / (1 - alphas_cumprod) - 1
# posterior variances only make sense for t>1, hence the array is short by 1
betas_post_np = betas_np[1:] * \
(1.0 - alpha_bars_np[:-1]) / (1.0 - alpha_bars_np[1:])
# we add beta_post_2 to the beginning of both beta arrays, since this is used as final decoder variance and
# requires special treatment (as in diffusion paper)
betas_post_init_np = np.append(betas_post_np[0], betas_post_np)
#betas_init_np = np.append(betas_post_np[0], betas_np[1:])
betas_init = torch.from_numpy(betas_np).float().cuda()
snr = torch.from_numpy(snr).float().cuda()
alphas = torch.from_numpy(alphas_np).float().cuda()
alpha_bars = torch.from_numpy(alpha_bars_np).float().cuda()
betas_post_init = torch.from_numpy(betas_post_init_np).float().cuda()
return betas_init, alphas, alpha_bars, betas_post_init, snr
# def _generate_betas_from_continuous_fun(self, diffusion_steps):
# t = np.arange(1, diffusion_steps + 1, dtype=np.float64)
# t = t / diffusion_steps
# # alpha_bars = self._alpha_bars_fun(t)
# alpha_bars = 1.0 - self._var_fun(torch.tensor(t)).numpy()
# betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
# betas = np.hstack((1 - alpha_bars[0], betas))
# return betas
def get_p_log_scales(self, timestep, stddev_type):
    """
    Look up the fixed log std-devs of the backward denoising process p.

    Args:
        timestep: One-based step indices; entry ``timestep - 1`` of the
            chosen variance table is gathered.
        stddev_type: ``'beta'`` uses ``self._betas_init``, ``'beta_post'``
            uses ``self._betas_post_init``, and ``'learn'`` means the
            std-devs are not fixed.

    Returns:
        ``0.5 * log(variance)`` for each timestep, or ``None`` for ``'learn'``.

    Raises:
        ValueError: If ``stddev_type`` is not one of the three options.
    """
    if stddev_type == 'learn':
        return None

    if stddev_type == 'beta':
        variance_table = self._betas_init
    elif stddev_type == 'beta_post':
        variance_table = self._betas_post_init
    else:
        raise ValueError('Unknown stddev_type: {}'.format(stddev_type))

    variance_t = torch.gather(variance_table, 0, timestep-1)
    return 0.5 * torch.log(variance_t)
# @torch.no_grad()
# def debug_run_denoising_diffusion(self, model, num_samples, shape, x_noisy, timestep,
# temp=1.0, enable_autocast=False, is_image=False, prior_var=1.0,
# condition_input=None):
# """
# Run the full denoising sampling loop.
# """
# # set model to eval mode
# # initialize sample
# #x_noisy_size = [num_samples] + shape
# #x_noisy = torch.randn(size=x_noisy_size, device='cuda') ## * np.sqrt(prior_var) * temp
# model.eval()
# x_noisy_size = x_noisy.shape
# x_noisy = x_noisy[0:1].expand(x_noisy.shape[0],-1,-1,-1) #
# timestep_start = timestep[0].item()
# output_list = []
# output_pred_list = []
# logger.info('timestep_start: {}', timestep_start)
# # denoising loop
# for t in reversed(range(0, self._diffusion_steps)):
# if t > timestep_start:
# continue
# if t % 100 == 0:
# logger.info('t={}', t)
# timestep = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1) # the model uses (1 ... T) without 0
# fixed_log_scales = self.get_p_log_scales(timestep=timestep, stddev_type=self._denoising_stddevs)
# mixing_component = self.get_mixing_component(x_noisy, timestep, enabled=model.mixed_prediction)
# # run model
# with autocast(enable_autocast):
# pred_logits = model(x=x_noisy, t=timestep.float() , condition_input=condition_input)
# # pred_logits = model(x_noisy, timestep.float() / self._diffusion_steps)
# logits = utils.get_mixed_prediction(model.mixed_prediction, pred_logits, model.mixing_logit, mixing_component)
# output_dist = utils.decoder_output('place_holder', logits, fixed_log_scales=fixed_log_scales)
# noise = torch.randn(size=x_noisy_size, device='cuda')
# mean = self.get_q_posterior_mean(x_noisy, output_dist.means, t)
# _, var_t_p, m_t_p, _, _, _ = self.iw_quantities_t(
# num_samples, timestep)
# pred_eps_t0 = (x_noisy - torch.sqrt(var_t_p) * pred_logits) / m_t_p
# if t == 0:
# x_image = mean
# else:
# x_noisy = mean + torch.exp(output_dist.log_scales) * noise * temp
# output_list.append(x_noisy)
# output_pred_list.append(pred_eps_t0)
# if is_image:
# x_image = x_image.clamp(min=-1., max=1.)
# x_image = utils.unsymmetrize_image_data(x_image)
# model.train()
# return x_image, output_list, output_pred_list
@torch.no_grad()
def run_denoising_diffusion(self, model, num_samples, shape, temp=1.0,
                            enable_autocast=False, is_image=False, prior_var=1.0,
                            condition_input=None, given_noise=None, clip_feat=None, cls_emb=None, grid_emb=None):
    """
    Run the full denoising sampling loop, from the last step down to the first.

    Args:
        model: Denoising network. It is called as
            ``model(x=..., t=..., condition_input=..., clip_feat=..., **kwargs)``
            and must expose ``mixed_prediction`` and ``mixing_logit``.
        num_samples: Number of samples (leading dimension of the state).
        shape: Per-sample shape; any sequence of ints.
        temp: Multiplier on the noise added after every step except the last.
        enable_autocast: Passed to ``autocast`` around the model call.
        is_image: If True, the result is clamped to ``[-1, 1]`` and passed
            through ``utils.unsymmetrize_image_data``.
        prior_var: Not used.
        condition_input: Optional conditioning tensor forwarded to the model.
        given_noise: Optional pair; ``given_noise[0]`` is the initial state
            and ``given_noise[1][t]`` the noise used at loop index ``t``.
            When None, both are drawn with ``torch.randn``.
        clip_feat: Forwarded to the model unchanged.
        cls_emb: Optional class embedding. It is concatenated to
            ``condition_input`` along dim 1, or used as the conditioning
            when ``condition_input`` is None.
        grid_emb: Forwarded to the model as ``grid_emb`` only when not None.

    Returns:
        ``(x_image, output_list)``. ``x_image`` is the posterior mean of the
        final step; ``output_list['pred_x']`` holds the state recorded after
        every loop iteration.
    """
    # set model to eval mode
    model.eval()
    # Work on the device that holds the schedule constants instead of a
    # hard-coded 'cuda'.
    device = self._alpha_bars.device

    # initialize sample
    x_noisy_size = [num_samples, *shape]
    if given_noise is None:
        x_noisy = torch.randn(size=x_noisy_size, device=device)
    else:
        x_noisy = given_noise[0]
    # Fallback result for the degenerate case of zero diffusion steps.
    x_image = x_noisy

    output_list = {'pred_x': []}

    kwargs = {}
    if grid_emb is not None:
        kwargs['grid_emb'] = grid_emb

    # Merge the class embedding into the conditioning exactly once. Doing
    # this inside the loop appended `cls_emb` again on every step, so the
    # conditioning tensor kept growing along dim 1.
    if cls_emb is not None:
        if condition_input is not None:
            condition_input = torch.cat([condition_input, cls_emb], dim=1)
        else:
            condition_input = cls_emb

    # denoising loop
    for t in reversed(range(0, self._diffusion_steps)):
        if t % 500 == 0:
            logger.info('t={}; shape={}, num_samples={}, sample shape: {}',
                        t, shape, num_samples, x_noisy.shape)
        # the model uses (1 ... T) without 0
        timestep = torch.ones(
            num_samples, dtype=torch.int64, device=device) * (t+1)
        fixed_log_scales = self.get_p_log_scales(
            timestep=timestep, stddev_type=self._denoising_stddevs)
        mixing_component = self.get_mixing_component(
            x_noisy, timestep, enabled=model.mixed_prediction)

        # run model
        with autocast(enable_autocast):
            pred_logits = model(x=x_noisy, t=timestep.float(),
                                condition_input=condition_input, clip_feat=clip_feat, **kwargs)
            logits = utils.get_mixed_prediction(
                model.mixed_prediction, pred_logits, model.mixing_logit, mixing_component)

        output_dist = utils.decoder_output(
            'place_holder', logits, fixed_log_scales=fixed_log_scales)
        if given_noise is None:
            noise = torch.randn(size=x_noisy_size, device=device)
        else:
            noise = given_noise[1][t]

        mean = self.get_q_posterior_mean(x_noisy, output_dist.means, t)
        if t == 0:
            # The final step returns the mean; no noise is added.
            x_image = mean
        else:
            x_noisy = mean + \
                torch.exp(output_dist.log_scales) * noise * temp
        output_list['pred_x'].append(x_noisy)

    if is_image:
        x_image = x_image.clamp(min=-1., max=1.)
        x_image = utils.unsymmetrize_image_data(x_image)

    # NOTE: the model is always left in train mode, whatever mode it was in
    # when this method was called.
    model.train()
    return x_image, output_list
def run_ddim_forward(self, dae, eps, ddim_step, ddim_skip_type, condition_input=None, clip_feat=None):
    """
    Placeholder for the DDIM / ODE forward pass; not implemented.

    The previous body was an unfinished draft. It used names that are
    neither parameters of this method nor among this file's imports
    (for example ``model``, ``num_samples``, ``shape``, ``skip_type``,
    ``grid_emb``, ``enable_autocast``, ``odeint``, ``ode_eps``,
    ``ode_solver_tol`` and ``get_mixed_prediction``), so every call failed
    with ``NameError`` on the first statement. The method now fails
    explicitly instead.

    Args:
        dae: Denoising network (unused).
        eps: Starting tensor for the integration (unused).
        ddim_step: Number of DDIM steps (unused).
        ddim_skip_type: Step-spacing scheme (unused).
        condition_input: Optional conditioning (unused).
        clip_feat: Optional feature input (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        'run_ddim_forward is not implemented; use run_ddim or '
        'run_denoising_diffusion for sampling.')
@torch.no_grad()
def run_ddim(self, model, num_samples, shape, temp=1.0, enable_autocast=False, is_image=True, prior_var=1.0,
             condition_input=None, ddim_step=100, skip_type='uniform', kappa=1.0, clip_feat=None, grid_emb=None,
             x_noisy=None, dae_index=-1):
    """
    Run the DDIM sampling loop over a sub-sequence of the diffusion steps.

    The update follows the reference implementation at
    https://github.com/ermongroup/ddim/blob/51cb290f83049e5381b09a4cc0389f16a4a02cc9/functions/denoising.py#L10

    Args:
        model: Denoising network. It is called as
            ``model(x=..., t=..., condition_input=..., clip_feat=..., **kwargs)``
            and must expose ``mixed_prediction`` and ``mixing_logit``.
        num_samples: Number of samples (leading dimension of the state).
        shape: Per-sample shape; any sequence of ints.
        temp: Not used.
        enable_autocast: Passed to ``autocast`` around the model call.
        is_image: Not used.
        prior_var: Not used.
        condition_input: Optional conditioning tensor forwarded to the model.
        ddim_step: Number of DDIM steps ``S``.
        skip_type: ``'uniform'`` (evenly spaced steps) or ``'quad'``
            (quadratically spaced steps).
        kappa: The eta of the DDIM algorithm; scales the per-step noise.
        clip_feat: Forwarded to the model unchanged.
        grid_emb: Forwarded to the model as ``grid_emb`` only when not None.
        x_noisy: Optional initial state; drawn with ``torch.randn`` if None.
        dae_index: Not used.

    Returns:
        ``(x_noisy, output_list)``: the final state and the list of states
        after every DDIM step.

    Raises:
        ValueError: If ``skip_type`` is neither ``'uniform'`` nor ``'quad'``.
    """
    num_train_steps = self._diffusion_steps
    S = ddim_step

    # Build the visited steps before touching the model, so that a bad
    # `skip_type` fails cleanly (it used to surface as an unbound local).
    if skip_type == 'uniform':
        # A single step has no spacing; avoid dividing by zero for S == 1.
        stride = (num_train_steps - 1.0) / (S - 1.0) if S > 1 else 0.0
        list_tau = [int(np.floor(i * stride)) for i in range(S)]
    elif skip_type == 'quad':
        seq = np.linspace(0, np.sqrt(num_train_steps * 0.8), S) ** 2
        list_tau = [int(s) for s in seq]
    else:
        raise ValueError('Unknown skip_type: {}'.format(skip_type))
    user_defined_steps = sorted(list_tau, reverse=True)
    T_user = len(user_defined_steps)

    # set model to eval mode
    model.eval()
    # Work on the device that holds the schedule constants instead of a
    # hard-coded 'cuda'.
    device = self._alpha_bars.device

    # initialize sample
    x_noisy_size = [num_samples, *shape]
    if x_noisy is None:
        x_noisy = torch.randn(size=x_noisy_size, device=device)
    else:
        x_noisy = x_noisy.to(device)
    output_list = []

    kwargs = {}
    if grid_emb is not None:
        kwargs['grid_emb'] = grid_emb

    Alpha_bar = self._alpha_bars
    # denoising loop
    for i, t in enumerate(user_defined_steps):
        if i % 500 == 0:
            logger.info('t={} / {}, ori={}', i, S, self._diffusion_steps)
        # the model uses (1 ... T) without 0
        timestep = torch.ones(
            num_samples, dtype=torch.int64, device=device) * (t+1)
        mixing_component = self.get_mixing_component(
            x_noisy, timestep, enabled=model.mixed_prediction)

        if i == T_user - 1:  # the next step is to generate x_0
            assert t == 0
            alpha_next = torch.tensor(1.0)
            sigma = torch.tensor(0.0)
        else:
            alpha_next = Alpha_bar[user_defined_steps[i+1]]
            sigma = kappa * \
                torch.sqrt(
                    (1-alpha_next) / (1-Alpha_bar[t]) * (1 - Alpha_bar[t] / alpha_next))

        # x_{t-1} = eps_coef * eps_theta + sigma * randn + sqrt(alpha_next / alpha_bar_t) * x_t
        x = x_noisy * torch.sqrt(alpha_next / Alpha_bar[t])
        eps_coef = torch.sqrt(1 - alpha_next - sigma ** 2) - torch.sqrt(
            1 - Alpha_bar[t]) * torch.sqrt(alpha_next / Alpha_bar[t])

        # --- run model forward --- #
        with autocast(enable_autocast):
            pred_logits = model(x=x_noisy, t=timestep.float(
            ), condition_input=condition_input, clip_feat=clip_feat, **kwargs)
            logits = utils.get_mixed_prediction(
                model.mixed_prediction, pred_logits, model.mixing_logit, mixing_component)
        epsilon_theta = logits

        # The noise is drawn on the CPU and then moved, as before, so seeded
        # runs reproduce the same samples.
        x += eps_coef * epsilon_theta + sigma * \
            torch.randn(x_noisy_size).to(x.device)
        x_noisy = x
        output_list.append(x_noisy)

    # NOTE: the model is always left in train mode, whatever mode it was in
    # when this method was called.
    model.train()
    return x_noisy, output_list
def get_q_posterior_mean(self, x_noisy, prediction, t):
    """Mean of the reverse step given the model's noise prediction.

    Implements line 4 of Algorithm 2 in DDPM. The last step (``t == 0``) is
    written in terms of ``alpha_bar_0`` and is used without added noise by
    the sampling loops.

    Args:
        x_noisy: Current noisy state.
        prediction: Model output used as the predicted noise.
        t: Zero-based step index (a Python int).

    Returns:
        The mean of the next (less noisy) state.
    """
    if t == 0:
        alpha_bar_0 = self._alpha_bars[0]
        residual = x_noisy - torch.sqrt(1.0 - alpha_bar_0) * prediction
        return 1.0 / torch.sqrt(alpha_bar_0) * residual

    scaled_noise = self._betas_init[t] * prediction / \
        torch.sqrt(1.0 - self._alpha_bars[t])
    residual = x_noisy - scaled_noise
    return 1.0 / torch.sqrt(self._alphas[t]) * residual
def get_mixing_component(self, x_noisy, timestep, enabled):
    """Return the mixing component ``sqrt(1 - alpha_bar_t) * x_noisy``.

    Args:
        x_noisy: Current noisy state.
        timestep: One-based step indices; entry ``timestep - 1`` of
            ``self._alpha_bars`` is gathered.
        enabled: Whether mixed prediction is in use.

    Returns:
        The mixing component when ``enabled`` is truthy, otherwise ``None``.
    """
    target_size = x_noisy.size()
    alpha_bar_t = torch.gather(self._alpha_bars, 0, timestep-1)
    if not enabled:
        return None

    # Reshape the per-sample scale with utils.view4D before multiplying.
    noise_scale = utils.view4D(torch.sqrt(1.0 - alpha_bar_t), target_size)
    return noise_scale * x_noisy
def mixing_component(self, eps, var, t, enabled):
    """Forward to ``get_mixing_component`` with ``(eps, t, enabled)``.

    Args:
        eps: Passed on as the noisy state.
        var: Ignored.
        t: Passed on as the timestep.
        enabled: Passed on unchanged.

    Returns:
        Whatever ``get_mixing_component`` returns.
    """
    component = self.get_mixing_component(eps, t, enabled)
    return component
@torch.no_grad()
def run_denoising_diffusion_from_t(self, model, num_samples, shape, time_start, x_noisy,
                                   temp=1.0, enable_autocast=False, is_image=False, prior_var=1.0,
                                   condition_input=None, given_noise=None):
    """
    Run the denoising sampling loop starting from an intermediate step.

    The loop visits the zero-based steps ``time_start - 1`` down to ``0``,
    starting from the caller-supplied ``x_noisy``.

    Args:
        model: Denoising network. It is called as
            ``model(x=..., t=..., condition_input=...)`` and must expose
            ``mixed_prediction`` and ``mixing_logit``.
        num_samples: Number of samples (leading dimension of the state).
        shape: Per-sample shape; any sequence of ints.
        time_start: Number of denoising steps to run.
        x_noisy: State to start denoising from.
        temp: Multiplier on the noise added after every step except the last.
        enable_autocast: Passed to ``autocast`` around the model call.
        is_image: Not used.
        prior_var: Not used.
        condition_input: Optional conditioning tensor forwarded to the model.
        given_noise: Optional; ``given_noise[1][t]`` is the noise used at
            loop index ``t``. When None, noise is drawn with ``torch.randn``.

    Returns:
        ``(x_image, output_list)``. ``x_image`` is the posterior mean of the
        final step, or ``x_noisy`` unchanged when ``time_start <= 0``;
        ``output_list`` holds the state recorded after every loop iteration.
    """
    # set model to eval mode
    model.eval()
    # Work on the device that holds the schedule constants instead of a
    # hard-coded 'cuda'.
    device = self._alpha_bars.device
    x_noisy_size = [num_samples, *shape]

    # With no steps to run the input is returned as is (this used to fail
    # with an unbound `x_image`).
    x_image = x_noisy
    output_list = []

    # denoising loop
    for t in reversed(range(0, time_start)):
        # the model uses (1 ... T) without 0
        timestep = torch.ones(
            num_samples, dtype=torch.int64, device=device) * (t+1)
        fixed_log_scales = self.get_p_log_scales(
            timestep=timestep, stddev_type=self._denoising_stddevs)
        mixing_component = self.get_mixing_component(
            x_noisy, timestep, enabled=model.mixed_prediction)

        # run model
        with autocast(enable_autocast):
            pred_logits = model(
                x=x_noisy, t=timestep.float(), condition_input=condition_input)
            logits = utils.get_mixed_prediction(
                model.mixed_prediction, pred_logits, model.mixing_logit, mixing_component)

        output_dist = utils.decoder_output(
            'place_holder', logits, fixed_log_scales=fixed_log_scales)
        if given_noise is None:
            noise = torch.randn(size=x_noisy_size, device=device)
        else:
            noise = given_noise[1][t]

        mean = self.get_q_posterior_mean(x_noisy, output_dist.means, t)
        if t == 0:
            # The final step returns the mean; no noise is added.
            x_image = mean
        else:
            x_noisy = mean + \
                torch.exp(output_dist.log_scales) * noise * temp
        output_list.append(x_noisy)

    # NOTE: the model is always left in train mode, whatever mode it was in
    # when this method was called.
    model.train()
    return x_image, output_list