# LION/utils/diffusion_continuous.py
# 2023-01-23 00:14:49 -05:00
#
# 846 lines
# 39 KiB
# Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_continuous.py"""
from abc import ABC, abstractmethod
import numpy as np
import torch
import gc
# import utils.distributions as distributions
from utils.utils import trace_df_dx_hutchinson, sample_gaussian_like, sample_rademacher_like, get_mixed_prediction
from third_party.torchdiffeq.torchdiffeq import odeint
from torch.cuda.amp import autocast
from timeit import default_timer as timer
from loguru import logger
def make_diffusion(args):
    """ simple diffusion factory function to return diffusion instances. Only use this to create continuous diffusions

    Args:
        args: configuration object; ``args.sde_type`` selects the diffusion class and
            ``args`` itself is forwarded unchanged to that class' constructor.

    Returns:
        An instance of the diffusion class registered for ``args.sde_type``.

    Raises:
        ValueError: if ``args.sde_type`` is not one of the known SDE names.
        NotImplementedError: if the SDE name is known but its class is not defined
            in this module (several implementations are currently commented out).
    """
    # Classes are looked up by name at call time: referencing a class that is
    # commented out in this module would otherwise surface as a bare NameError.
    sde_class_names = {
        'geometric_sde': 'DiffusionGeometric',
        'vpsde': 'DiffusionVPSDE',
        'sub_vpsde': 'DiffusionSubVPSDE',
        'power_vpsde': 'DiffusionPowerVPSDE',
        'sub_power_vpsde': 'DiffusionSubPowerVPSDE',
        'vesde': 'DiffusionVESDE',
    }
    sde_type = args.sde_type
    class_name = sde_class_names.get(sde_type) if isinstance(sde_type, str) else None
    if class_name is None:
        raise ValueError("Unrecognized sde type: {}".format(sde_type))

    diffusion_cls = globals().get(class_name)
    if diffusion_cls is None:
        raise NotImplementedError(
            "sde type '{}' requires {}, which is not defined in this module".format(
                sde_type, class_name))
    return diffusion_cls(args)
class DiffusionBase(ABC):
    """
    Abstract base class for all diffusion implementations.

    Subclasses provide the SDE coefficients (``f``, ``g2``), the marginal
    variance and mean coefficient (``var``, ``e2int_f``), the inverse of the
    variance function (``inv_var``) and the mixing component. This class adds
    forward sampling, ODE integration helpers and the time / loss-weight
    sampling used during training.
    """

    def __init__(self, args):
        """Store the settings shared by every SDE.

        Args:
            args: configuration object; ``args.sigma2_0`` and ``args.sde_type`` are read.
        """
        super().__init__()
        self.sigma2_0 = args.sigma2_0
        self.sde_type = args.sde_type

    @staticmethod
    def _default_device():
        """ returns 'cuda' when CUDA is available, otherwise 'cpu' (used for tensors created internally) """
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    @abstractmethod
    def f(self, t):
        """ returns the drift coefficient at time t: f(t) """
        pass

    @abstractmethod
    def g2(self, t):
        """ returns the squared diffusion coefficient at time t: g^2(t) """
        pass

    @abstractmethod
    def var(self, t):
        r""" returns variance at time t, \sigma_t^2
        q(zt|z0) = N(zt; \mu_t(z0), \sigma_t^2 I)
        """
        pass

    @abstractmethod
    def e2int_f(self, t):
        r""" returns e^{\int_0^t f(s) ds} which corresponds to the coefficient of mean at time t. """
        pass

    @abstractmethod
    def inv_var(self, var):
        """ inverse of the variance function at input variance var. """
        pass

    @abstractmethod
    def mixing_component(self, x_noisy, var_t, t, enabled):
        """ returns mixing component which is the optimal denoising model assuming that q(z_0) is N(0, 1) """
        pass

    def sample_q(self, x_init, noise, var_t, m_t):
        """ returns a sample from diffusion process at time t

        Computes ``m_t * x_init + sqrt(var_t) * noise``; all arguments are combined
        with ordinary tensor broadcasting.
        """
        return m_t * x_init + torch.sqrt(var_t) * noise

    def cross_entropy_const(self, ode_eps):
        """ returns cross entropy factor with variance according to ode integration cutoff ode_eps

        The value is ``0.5 * (1 + log(2 * pi * var(ode_eps)))``.
        """
        var_eps = self.var(t=torch.tensor(ode_eps, device=self._default_device()))
        return 0.5 * (1.0 + torch.log(2.0 * np.pi * var_eps))

    def compute_ode_nll(self, dae, eps, ode_eps, ode_solver_tol, enable_autocast=False,
                        no_autograd=False, num_samples=1, report_std=False,
                        condition_input=None, clip_feat=None):
        """ integrates the ODE from t=ode_eps to t=1.0 starting at eps and returns the final state

        The log-probability integration this method is named after is disabled
        (see the note at the end of the body), so no NLL is returned.

        Args:
            dae: denoising model; called as ``dae(x=, t=, condition_input=, clip_feat=)`` and
                read for ``mixed_prediction`` and ``mixing_logit``. Put into eval mode here.
            eps: initial state of the ODE at time ``ode_eps``.
            ode_eps: lower integration limit.
            ode_solver_tol: used for both ``atol`` and ``rtol`` of the solver.
            enable_autocast: forwarded to ``autocast(enabled=...)``.
            no_autograd: unused by the active code path.
            num_samples: number of times the solve is repeated; only the last result is returned.
            report_std: unused by the active code path.
            condition_input: forwarded to ``dae``.
            clip_feat: forwarded to ``dae``.

        Returns:
            The last entry of the ``odeint`` output, i.e. the state at t=1.0.

        Raises:
            ValueError: if ``num_samples`` is smaller than 1 (nothing would be integrated).
        """
        # NFE counter (module-level, shared with sample_model_ode)
        global nfe_counter

        if num_samples < 1:
            raise ValueError("num_samples must be >= 1, got {}".format(num_samples))

        # ODE solver starts consuming the CPU memory without this on large models
        # https://github.com/scipy/scipy/issues/10070
        gc.collect()
        dae.eval()

        def ode_func(t, x):
            """ the ode function (the log probability integration for NLL calculation is disabled) """
            global nfe_counter
            nfe_counter = nfe_counter + 1
            x = x.detach()
            x.requires_grad_(False)
            with torch.set_grad_enabled(False):
                with autocast(enabled=enable_autocast):
                    variance = self.var(t=t)
                    mixing_component = self.mixing_component(
                        x_noisy=x, var_t=variance, t=t, enabled=dae.mixed_prediction)
                    pred_params = dae(
                        x=x, t=t, condition_input=condition_input, clip_feat=clip_feat)
                    # Warning: here mixing_logit can be None
                    params = get_mixed_prediction(
                        dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component)
                    dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \
                        params / torch.sqrt(variance)
            return dx_dt

        # integration limits live on the same device as the initial state
        t_span = torch.tensor([ode_eps, 1.0], device=eps.device)
        for _ in range(num_samples):
            nfe_counter = 0
            # solve the ODE
            x_t = odeint(
                ode_func,
                eps,
                t_span,
                atol=ode_solver_tol,  # 1e-5
                rtol=ode_solver_tol,  # 1e-5
                # 'dopri5' or 'dopri8' methods also seems good.
                method="scipy_solver",
                options={"solver": 'RK45'},  # only for scipy solvers
            )
            # last output values
            x_t0 = x_t[-1]
            # NOTE(review): assumed to be printed once per solve; the original
            # nesting of this statement could not be recovered.
            print('nfe_counter: ', nfe_counter)

        # NOTE: the prior log-probability, the Hutchinson trace estimate of the
        # divergence and the mean / std / std-error statistics of the NLL over
        # num_samples were disabled upstream; the formerly documented return value was
        # (nll_mean, nfe_mean, nll_stddev_batch, nll_stderror_batch).
        return x_t0

    def sample_model_ode(self, dae, num_samples, shape, ode_eps,
                         ode_solver_tol, enable_autocast, temp, noise=None,
                         condition_input=None, mixing_logit=None,
                         use_cust_ode_func=0, init_t=1.0, return_all_sample=False, clip_feat=None
                         ):
        """ generates samples using the ODE framework, assuming integration cutoff ode_eps

        Args:
            dae: denoising model; put into eval mode here.
            num_samples: batch size of the initial noise (only used when ``noise`` is None).
            shape: per-sample shape of the initial noise (any sequence of ints; only used
                when ``noise`` is None).
            ode_eps: final integration time.
            ode_solver_tol: used for both ``atol`` and ``rtol`` of the solver.
            enable_autocast: forwarded to ``autocast(enabled=...)``.
            temp: factor multiplied onto the initial noise.
            noise: optional initial noise; drawn from a standard normal when None.
            condition_input: forwarded to ``dae``.
            mixing_logit: overrides ``dae.mixing_logit`` when not None (default ode function only).
            use_cust_ode_func: when truthy, ``cust_ode_func`` is integrated instead of ``ode_func``.
            init_t: initial integration time.
            return_all_sample: when True the full ``odeint`` output is returned as well.
            clip_feat: forwarded to ``dae`` (default ode function only).

        Returns:
            ``(final_sample, nfe_counter, ode_solve_time)``, or
            ``(final_sample, all_samples, nfe_counter, ode_solve_time)`` when
            ``return_all_sample`` is True.
        """
        # NFE counter (module-level, shared with compute_ode_nll)
        global nfe_counter

        # ODE solver starts consuming the CPU memory without this on large models
        # https://github.com/scipy/scipy/issues/10070
        gc.collect()
        dae.eval()

        def cust_ode_func(t, x):
            """ the ode function (sampling only, no NLL stuff); uses the raw model output without mixing """
            global nfe_counter
            nfe_counter = nfe_counter + 1
            if nfe_counter % 100 == 0:
                logger.info('nfe_counter={}', nfe_counter)
            with autocast(enabled=enable_autocast):
                variance = self.var(t=t)
                params = dae(x, x, t, condition_input=condition_input)
                dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \
                    params / torch.sqrt(variance)
            return dx_dt

        def ode_func(t, x):
            """ the ode function (sampling only, no NLL stuff) """
            global nfe_counter
            nfe_counter = nfe_counter + 1
            if nfe_counter % 100 == 0:
                logger.info('nfe_counter={}', nfe_counter)
            with autocast(enabled=enable_autocast):
                variance = self.var(t=t)
                mixing_component = self.mixing_component(
                    x_noisy=x, var_t=variance, t=t, enabled=dae.mixed_prediction)
                pred_params = dae(
                    x=x, t=t, condition_input=condition_input, clip_feat=clip_feat)
                input_mixing_logit = mixing_logit if mixing_logit is not None else dae.mixing_logit
                params = get_mixed_prediction(
                    dae.mixed_prediction, pred_params, input_mixing_logit, mixing_component)
                dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \
                    params / torch.sqrt(variance)
            return dx_dt

        # the initial noise
        if noise is None:
            # list(shape) so that tuples / torch.Size are accepted as well as lists
            noise = torch.randn(size=[num_samples] + list(shape),
                                device=self._default_device())
        if self.sde_type == 'vesde':
            noise_init = temp * noise * np.sqrt(self.sigma2_max)
        else:
            noise_init = temp * noise

        nfe_counter = 0
        # solve the ODE
        start = timer()
        samples_out = odeint(
            ode_func if not use_cust_ode_func else cust_ode_func,
            noise_init,
            # integration limits live on the same device as the initial state
            torch.tensor([init_t, ode_eps], device=noise_init.device),
            atol=ode_solver_tol,  # 1e-5
            rtol=ode_solver_tol,  # 1e-5
            # 'dopri5' or 'dopri8' methods also seems good.
            method="scipy_solver",
            options={"solver": 'RK45'},  # only for scipy solvers
        )
        end = timer()
        ode_solve_time = end - start
        if return_all_sample:
            return samples_out[-1], samples_out, nfe_counter, ode_solve_time
        return samples_out[-1], nfe_counter, ode_solve_time

    # def compute_dsm_nll(self, dae, eps, time_eps, enable_autocast, num_samples, report_std):
    # with torch.no_grad():
    # neg_log_prob_all = []
    # for i in range(num_samples):
    # assert self.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde'], "we don't support subVPSDE yet."
    # t, var_t, m_t, obj_weight_t, _, _ = \
    # self.iw_quantities(eps.shape[0], time_eps, iw_sample_mode='ll_iw', iw_subvp_like_vp_sde=False)
    # noise = torch.randn(size=eps.size(), device='cuda')
    # eps_t = self.sample_q(eps, noise, var_t, m_t)
    # mixing_component = self.mixing_component(eps_t, var_t, t, enabled=dae.mixed_prediction)
    # with autocast(enabled=enable_autocast):
    # pred_params = dae(eps_t, t)
    # params = get_mixed_prediction(dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component)
    # l2_term = torch.square(params - noise)
    # neg_log_prob_per_var = obj_weight_t * l2_term
    # neg_log_prob_per_var += self.cross_entropy_const(time_eps)
    # neg_log_prob = torch.sum(neg_log_prob_per_var, dim=[1, 2, 3])
    # neg_log_prob_all.append(neg_log_prob)
    # neg_log_prob_all = torch.stack(neg_log_prob_all, dim=1)
    # nll = torch.mean(neg_log_prob_all, dim=1)
    # if num_samples > 1 and report_std:
    # nll_std = torch.std(neg_log_prob_all, dim=1)
    # print('std nll:', nll_std)
    # return nll

    def iw_quantities(self, size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde):
        """ samples times and the matching variance / mean coefficient / loss weights for this SDE type

        Dispatches on ``self.sde_type`` to the matching ``_iw_quantities_*`` helper.

        Returns:
            Whatever the selected helper returns (for the VPSDE-like helper: the tuple
            ``(t, var_t, m_t, obj_weight_t_p, obj_weight_t_q, g2_t)``).

        Raises:
            NotImplementedError: if the SDE type is unknown, or if its helper is not
                defined on this object (the sub-VPSDE and VESDE helpers are currently
                commented out in this module).
        """
        if self.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde']:
            return self._iw_quantities_vpsdelike(size, time_eps, iw_sample_mode)
        if self.sde_type in ['sub_vpsde', 'sub_power_vpsde']:
            helper = getattr(self, '_iw_quantities_subvpsdelike', None)
            if helper is None:
                raise NotImplementedError(
                    "_iw_quantities_subvpsdelike is not available for sde type '{}'".format(self.sde_type))
            return helper(size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde)
        if self.sde_type in ['vesde']:
            helper = getattr(self, '_iw_quantities_vesde', None)
            if helper is None:
                raise NotImplementedError(
                    "_iw_quantities_vesde is not available for sde type '{}'".format(self.sde_type))
            return helper(size, time_eps, iw_sample_mode)
        raise NotImplementedError(
            "no importance-weighting quantities for sde type '{}'".format(self.sde_type))

    def debug_sheduler(self, time_eps):
        """ evaluates the schedule on the fixed grid t = 1/1000, 2/1000, ..., 1 (for debugging)

        ``time_eps`` is accepted but not used. Returns the same 6-tuple layout as the
        'drop_all_uniform' mode of ``_iw_quantities_vpsdelike``.
        """
        # torch.arange with an explicit float dtype yields the same 1000 values
        # (1.0 ... 1000.0) as the deprecated torch.range(1, 1000).
        t = torch.arange(1, 1001, dtype=torch.get_default_dtype()) / 1000.0
        t = t.to(self._default_device())
        var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
        obj_weight_t_p = torch.ones(1, device=t.device)
        obj_weight_t_q = g2_t / (2.0 * var_t)
        return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), \
            obj_weight_t_p.view(-1, 1, 1, 1), \
            obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)

    def _iw_quantities_vpsdelike(self, size, time_eps, iw_sample_mode):
        """
        For all SDEs where the underlying SDE is of the form dz = -0.5 * beta(t) * z * dt + sqrt{beta(t)} * dw, like
        for the VPSDE.

        Args:
            size: number of time points to draw.
            time_eps: smallest time that can be drawn.
            iw_sample_mode: one of 'll_uniform', 'll_iw', 'drop_all_uniform', 'drop_all_iw',
                'drop_sigma2t_iw', 'drop_sigma2t_uniform', 'rescale_iw'.

        Returns:
            ``(t, var_t, m_t, obj_weight_t_p, obj_weight_t_q, g2_t)``; ``t`` has shape
            ``[size]`` and every other entry is reshaped with ``view(-1, 1, 1, 1)``.

        Raises:
            ValueError: for an unknown ``iw_sample_mode``.
        """
        device = self._default_device()
        rho = torch.rand(size=[size], device=device)
        if iw_sample_mode == 'll_uniform':
            # uniform t sampling - likelihood obj. for both q and p
            t = rho * (1. - time_eps) + time_eps
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t)
        elif iw_sample_mode == 'll_iw':
            # importance sampling for likelihood obj. - likelihood obj. for both q and p
            ones = torch.ones_like(rho)
            sigma2_1, sigma2_eps = self.var(ones), self.var(time_eps * ones)
            log_sigma2_1, log_sigma2_eps = torch.log(
                sigma2_1), torch.log(sigma2_eps)
            # log-linear interpolation of the variance between its values at time_eps and 1
            var_t = torch.exp(rho * log_sigma2_1 + (1 - rho) * log_sigma2_eps)
            t = self.inv_var(var_t)
            m_t, g2_t = self.e2int_f(t), self.g2(t)
            obj_weight_t_p = obj_weight_t_q = 0.5 * \
                (log_sigma2_1 - log_sigma2_eps) / (1.0 - var_t)
        elif iw_sample_mode == 'drop_all_uniform':
            # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
            t = rho * (1. - time_eps) + time_eps
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = torch.ones(1, device=device)
            obj_weight_t_q = g2_t / (2.0 * var_t)
        elif iw_sample_mode == 'drop_all_iw':
            # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
            assert self.sde_type == 'vpsde', 'Importance sampling for fully unweighted objective is currently only ' \
                'implemented for the regular VPSDE.'
            # relies on the auxiliary constants that DiffusionVPSDE precomputes
            t = torch.sqrt(1.0 / self.delta_beta_half) * torch.erfinv(rho *
                                                                      self.const_norm_2 + self.const_erf) - self.beta_frac
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = self.const_norm / (1.0 - var_t)
            obj_weight_t_q = obj_weight_t_p * g2_t / (2.0 * var_t)
        elif iw_sample_mode == 'drop_sigma2t_iw':
            # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
            ones = torch.ones_like(rho)
            sigma2_1, sigma2_eps = self.var(ones), self.var(time_eps * ones)
            # linear interpolation of the variance between its values at time_eps and 1
            var_t = rho * sigma2_1 + (1 - rho) * sigma2_eps
            t = self.inv_var(var_t)
            m_t, g2_t = self.e2int_f(t), self.g2(t)
            obj_weight_t_p = 0.5 * (sigma2_1 - sigma2_eps) / (1.0 - var_t)
            obj_weight_t_q = obj_weight_t_p / var_t
        elif iw_sample_mode == 'drop_sigma2t_uniform':
            # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
            t = rho * (1. - time_eps) + time_eps
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = g2_t / 2.0
            obj_weight_t_q = g2_t / (2.0 * var_t)
        elif iw_sample_mode == 'rescale_iw':
            # 1/(1-sigma2_t) resc. obj. for p, likelihood obj. for q.
            # Despite the mode name, t is drawn uniformly here.
            t = rho * (1. - time_eps) + time_eps
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = 0.5 / (1.0 - var_t)
            obj_weight_t_q = g2_t / (2.0 * var_t)
        else:
            raise ValueError(
                "Unrecognized importance sampling type: {}".format(iw_sample_mode))
        return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \
            obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)
# def _iw_quantities_subvpsdelike(self, size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde):
# """
# For all SDEs where the underlying SDE is of the form
# dz = -0.5 * beta(t) * z * dt + sqrt{beta(t) * (1 - exp[-2 * betaintegral])} * dw, like for the Sub-VPSDE.
# When iw_subvp_like_vp_sde is True, then we define the importance sampling distributions based on an analogous
# VPSDE, while stile using the Sub-VPSDE. The motivation is that deriving the correct importance sampling
# distributions for the Sub-VPSDE itself is hard, but the importance sampling distributions from analogous VPSDEs
# probably already significantly reduce the variance also for the Sub-VPSDE.
# """
# rho = torch.rand(size=[size], device='cuda')
# if iw_sample_mode == 'll_uniform':
# # uniform t sampling - likelihood obj. for both q and p
# t = rho * (1. - time_eps) + time_eps
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t)
# elif iw_sample_mode == 'll_iw':
# if iw_subvp_like_vp_sde:
# # importance sampling for vpsde likelihood obj. - sub-vpsde likelihood obj. for both q and p
# ones = torch.ones_like(rho, device='cuda')
# sigma2_1, sigma2_eps = self.var_vpsde(ones), self.var_vpsde(time_eps * ones)
# log_sigma2_1, log_sigma2_eps = torch.log(sigma2_1), torch.log(sigma2_eps)
# var_t_vpsde = torch.exp(rho * log_sigma2_1 + (1 - rho) * log_sigma2_eps)
# t = self.inv_var_vpsde(var_t_vpsde)
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t) * \
# (log_sigma2_1 - log_sigma2_eps) * var_t_vpsde / (1 - var_t_vpsde) / self.beta(t)
# else:
# raise NotImplementedError
# elif iw_sample_mode == 'drop_all_uniform':
# # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
# t = rho * (1. - time_eps) + time_eps
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = torch.ones(1, device='cuda')
# obj_weight_t_q = g2_t / (2.0 * var_t)
# elif iw_sample_mode == 'drop_all_iw':
# if iw_subvp_like_vp_sde:
# # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
# assert self.sde_type == 'sub_vpsde', 'Importance sampling for fully unweighted objective is ' \
# 'currently only implemented for the Sub-VPSDE.'
# t = torch.sqrt(1.0 / self.delta_beta_half) * torch.erfinv(rho * self.const_norm_2 + self.const_erf) - self.beta_frac
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = self.const_norm / (1.0 - self.var_vpsde(t))
# obj_weight_t_q = obj_weight_t_p * g2_t / (2.0 * var_t)
# else:
# raise NotImplementedError
# elif iw_sample_mode == 'drop_sigma2t_iw':
# if iw_subvp_like_vp_sde:
# # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
# ones = torch.ones_like(rho, device='cuda')
# sigma2_1, sigma2_eps = self.var_vpsde(ones), self.var_vpsde(time_eps * ones)
# var_t_vpsde = rho * sigma2_1 + (1 - rho) * sigma2_eps
# t = self.inv_var_vpsde(var_t_vpsde)
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = 0.5 * g2_t / self.beta(t) * (sigma2_1 - sigma2_eps) / (1.0 - var_t_vpsde)
# obj_weight_t_q = obj_weight_t_p / var_t
# else:
# raise NotImplementedError
# elif iw_sample_mode == 'drop_sigma2t_uniform':
# # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
# t = rho * (1. - time_eps) + time_eps
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = g2_t / 2.0
# obj_weight_t_q = g2_t / (2.0 * var_t)
# elif iw_sample_mode == 'rescale_iw':
# # importance sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p
# # Note that we use the sub-vpsde variance to scale the p objective! It's not clear what's optimal here!
# t = rho * (1. - time_eps) + time_eps
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = 0.5 / (1.0 - var_t)
# obj_weight_t_q = g2_t / (2.0 * var_t)
# else:
# raise ValueError("Unrecognized importance sampling type: {}".format(iw_sample_mode))
# return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \
# obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)
# def _iw_quantities_vesde(self, size, time_eps, iw_sample_mode):
# """
# For the VESDE.
# """
# rho = torch.rand(size=[size], device='cuda')
# if iw_sample_mode == 'll_uniform':
# # uniform t sampling - likelihood obj. for both q and p
# t = rho * (1. - time_eps) + time_eps
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t)
# elif iw_sample_mode == 'll_iw':
# # importance sampling for likelihood obj. - likelihood obj. for both q and p
# ones = torch.ones_like(rho, device='cuda')
# nsigma2_1, nsigma2_eps, sigma2_eps = self.var_N(ones), self.var_N(time_eps * ones), self.var(time_eps * ones)
# log_frac_sigma2_1, log_frac_sigma2_eps = torch.log(self.sigma2_max / nsigma2_1), torch.log(nsigma2_eps / sigma2_eps)
# var_N_t = (1.0 - self.sigma2_min) / (1.0 - torch.exp(rho * (log_frac_sigma2_1 + log_frac_sigma2_eps) - log_frac_sigma2_eps))
# t = self.inv_var_N(var_N_t)
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = obj_weight_t_q = 0.5 * (log_frac_sigma2_1 + log_frac_sigma2_eps) * self.var_N(t) / (1.0 - self.sigma2_min)
# elif iw_sample_mode == 'drop_all_uniform':
# # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
# t = rho * (1. - time_eps) + time_eps
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = torch.ones(1, device='cuda')
# obj_weight_t_q = g2_t / (2.0 * var_t)
# elif iw_sample_mode == 'drop_all_iw':
# # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
# ones = torch.ones_like(rho, device='cuda')
# nsigma2_1, nsigma2_eps, sigma2_eps = self.var_N(ones), self.var_N(time_eps * ones), self.var(time_eps * ones)
# log_frac_sigma2_1, log_frac_sigma2_eps = torch.log(self.sigma2_max / nsigma2_1), torch.log(nsigma2_eps / sigma2_eps)
# var_N_t = (1.0 - self.sigma2_min) / (1.0 - torch.exp(rho * (log_frac_sigma2_1 + log_frac_sigma2_eps) - log_frac_sigma2_eps))
# t = self.inv_var_N(var_N_t)
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_q = 0.5 * (log_frac_sigma2_1 + log_frac_sigma2_eps) * self.var_N(t) / (1.0 - self.sigma2_min)
# obj_weight_t_p = 2.0 * obj_weight_t_q / np.log(self.sigma2_max / self.sigma2_min)
# elif iw_sample_mode == 'drop_sigma2t_iw':
# # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
# ones = torch.ones_like(rho, device='cuda')
# nsigma2_1, nsigma2_eps = self.var_N(ones), self.var_N(time_eps * ones)
# var_N_t = torch.exp(rho * torch.log(nsigma2_1) + (1 - rho) * torch.log(nsigma2_eps))
# t = self.inv_var_N(var_N_t)
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = 0.5 * torch.log(nsigma2_1 / nsigma2_eps) * self.var_N(t)
# obj_weight_t_q = obj_weight_t_p / var_t
# elif iw_sample_mode == 'drop_sigma2t_uniform':
# # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
# t = rho * (1. - time_eps) + time_eps
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = g2_t / 2.0
# obj_weight_t_q = g2_t / (2.0 * var_t)
# elif iw_sample_mode == 'rescale_iw':
# # uniform sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p
# t = rho * (1. - time_eps) + time_eps
# var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
# obj_weight_t_p = 0.5 / (1.0 - var_t)
# obj_weight_t_q = g2_t / (2.0 * var_t)
# else:
# raise ValueError("Unrecognized importance sampling type: {}".format(iw_sample_mode))
# return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \
# obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)
# class DiffusionGeometric(DiffusionBase):
# """
# Diffusion implementation with dz = -0.5 * beta(t) * z * dt + sqrt(beta(t)) * dW SDE and geometric progression of
# variance. This is our new diffusion.
# """
# def __init__(self, args):
# super().__init__(args)
# self.sigma2_min = args.sigma2_min
# self.sigma2_max = args.sigma2_max
#
# def f(self, t):
# return -0.5 * self.g2(t)
#
# def g2(self, t):
# sigma2_geom = self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t)
# log_term = np.log(self.sigma2_max / self.sigma2_min)
# return sigma2_geom * log_term / (1.0 - self.sigma2_0 + self.sigma2_min - sigma2_geom)
#
# def var(self, t):
# return self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t) - self.sigma2_min + self.sigma2_0
#
# def e2int_f(self, t):
# return torch.sqrt(1.0 + self.sigma2_min * (1.0 - (self.sigma2_max / self.sigma2_min) ** t) / (1.0 - self.sigma2_0))
#
# def inv_var(self, var):
# return torch.log((var + self.sigma2_min - self.sigma2_0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min)
#
# def mixing_component(self, x_noisy, var_t, t, enabled):
# if enabled:
# return torch.sqrt(var_t) * x_noisy
# else:
# return None
#
class DiffusionVPSDE(DiffusionBase):
    """
    Diffusion implementation of the VPSDE. This uses the same SDE like DiffusionGeometric but with linear beta(t).
    Note that we need to scale beta_start and beta_end by 1000 relative to JH's DDPM values, since our t is in [0,1].
    """

    def __init__(self, args):
        """Read the linear beta schedule and precompute the importance-sampling constants.

        Args:
            args: configuration object; ``beta_start``, ``beta_end`` and ``time_eps`` are
                read here, in addition to what ``DiffusionBase.__init__`` reads.
        """
        super().__init__(args)
        self.beta_start = args.beta_start
        self.beta_end = args.beta_end
        logger.info('VPSDE: beta_start={}, beta_end={}, sigma2_0={}',
                    self.beta_start, self.beta_end, self.sigma2_0)

        # auxiliary constants (yes, this is not super clean...)
        # They are consumed by the 'drop_all_iw' sampling mode of the base class.
        self.time_eps = args.time_eps
        # keep the constants on the GPU when there is one, but stay constructible without CUDA
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.delta_beta_half = torch.tensor(
            0.5 * (self.beta_end - self.beta_start), device=device)
        self.beta_frac = torch.tensor(
            self.beta_start / (self.beta_end - self.beta_start), device=device)
        sqrt_delta_beta_half = torch.sqrt(self.delta_beta_half)
        self.const_aq = (1.0 - self.sigma2_0) * torch.exp(0.5 * self.beta_frac) * \
            torch.sqrt(0.25 * np.pi / self.delta_beta_half)
        self.const_erf = torch.erf(
            sqrt_delta_beta_half * (self.time_eps + self.beta_frac))
        # erf difference between t = 1 and t = time_eps; const_norm is the same quantity scaled by const_aq
        self.const_norm_2 = torch.erf(
            sqrt_delta_beta_half * (1.0 + self.beta_frac)) - self.const_erf
        self.const_norm = self.const_aq * self.const_norm_2

    def f(self, t):
        """ returns the drift coefficient f(t) = -0.5 * g^2(t) """
        return -0.5 * self.g2(t)

    def g2(self, t):
        """ returns g^2(t), which is linear in t: beta_start at t=0 and beta_end at t=1 """
        return self.beta_start + (self.beta_end - self.beta_start) * t

    def var(self, t):
        """ returns the variance at time t: 1 - (1 - sigma2_0) * exp(-beta_start * t - 0.5 * (beta_end - beta_start) * t^2) """
        return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t)

    def e2int_f(self, t):
        """ returns the coefficient of the mean at time t: exp(-0.5 * beta_start * t - 0.25 * (beta_end - beta_start) * t^2) """
        return torch.exp(-0.5 * self.beta_start * t - 0.25 * (self.beta_end - self.beta_start) * t * t)

    def inv_var(self, var):
        """ returns the time t at which var(t) equals the given variance

        Solves the quadratic ``0.5 * a * t^2 + beta_start * t + c = 0`` with
        ``a = beta_end - beta_start`` and ``c = log((1 - var) / (1 - sigma2_0))``,
        taking the root with the positive square root.
        """
        c = torch.log((1 - var) / (1 - self.sigma2_0))
        a = self.beta_end - self.beta_start
        t = (-self.beta_start + torch.sqrt(np.square(self.beta_start) - 2 * a * c)) / a
        return t

    def mixing_component(self, x_noisy, var_t, t, enabled):
        """ returns sqrt(var_t) * x_noisy when enabled, otherwise None (t is not used) """
        if enabled:
            return torch.sqrt(var_t) * x_noisy
        else:
            return None
# class DiffusionSubVPSDE(DiffusionBase):
# """
# Diffusion implementation of the sub-VPSDE. Note that this uses a different SDE compared to the above two diffusions.
# """
# def __init__(self, args):
# super().__init__(args)
# self.beta_start = args.beta_start
# self.beta_end = args.beta_end
#
# # auxiliary constants (assumes regular VPSDE... yes, this is not super clean...)
# self.time_eps = args.time_eps
# self.delta_beta_half = torch.tensor(0.5 * (self.beta_end - self.beta_start), device='cuda')
# self.beta_frac = torch.tensor(self.beta_start / (self.beta_end - self.beta_start), device='cuda')
# self.const_aq = (1.0 - self.sigma2_0) * torch.exp(0.5 * self.beta_frac) * torch.sqrt(0.25 * np.pi / self.delta_beta_half)
# self.const_erf = torch.erf(torch.sqrt(self.delta_beta_half) * (self.time_eps + self.beta_frac))
# self.const_norm = self.const_aq * (torch.erf(torch.sqrt(self.delta_beta_half) * (1.0 + self.beta_frac)) - self.const_erf)
# self.const_norm_2 = torch.erf(torch.sqrt(self.delta_beta_half) * (1.0 + self.beta_frac)) - self.const_erf
#
# def f(self, t):
# return -0.5 * self.beta(t)
#
# def g2(self, t):
# return self.beta(t) * (1.0 - torch.exp(-2.0 * self.beta_start * t - (self.beta_end - self.beta_start) * t * t))
#
# def var(self, t):
# int_term = torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t)
# return torch.square(1.0 - int_term) + self.sigma2_0 * int_term
#
# def e2int_f(self, t):
# return torch.exp(-0.5 * self.beta_start * t - 0.25 * (self.beta_end - self.beta_start) * t * t)
#
# def beta(self, t):
# """ auxiliary beta function """
# return self.beta_start + (self.beta_end - self.beta_start) * t
#
# def inv_var(self, var):
# raise NotImplementedError
#
# def mixing_component(self, x_noisy, var_t, t, enabled):
# if enabled:
# int_term = torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t).view(-1, 1, 1, 1)
# return torch.sqrt(var_t) * x_noisy / (torch.square(1.0 - int_term) + int_term)
# else:
# return None
#
# def var_vpsde(self, t):
# return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t)
#
# def inv_var_vpsde(self, var):
# c = torch.log((1 - var) / (1 - self.sigma2_0))
# a = self.beta_end - self.beta_start
# t = (-self.beta_start + torch.sqrt(np.square(self.beta_start) - 2 * a * c)) / a
# return t
#
#
# class DiffusionPowerVPSDE(DiffusionBase):
# """
# Diffusion implementation of the power-VPSDE. This uses the same SDE like DiffusionGeometric but with beta function
# that is a power function with user specified power. Note that for power=1, this reproduces the vanilla
# DiffusionVPSDE above.
# """
# def __init__(self, args):
# super().__init__(args)
# self.beta_start = args.beta_start
# self.beta_end = args.beta_end
# self.power = args.vpsde_power
#
# def f(self, t):
# return -0.5 * self.g2(t)
#
# def g2(self, t):
# return self.beta_start + (self.beta_end - self.beta_start) * t ** self.power
#
# def var(self, t):
# return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#
# def e2int_f(self, t):
# return torch.exp(-0.5 * self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#
# def inv_var(self, var):
# if self.power == 2:
# c = torch.log((1 - var) / (1 - self.sigma2_0))
# p = 3.0 * self.beta_start / (self.beta_end - self.beta_start)
# q = 3.0 * c / (self.beta_end - self.beta_start)
# a = -0.5 * q + torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0)
# b = -0.5 * q - torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0)
# return torch.pow(a, 1.0 / 3.0) + torch.pow(b, 1.0 / 3.0)
# else:
# raise NotImplementedError
#
# def mixing_component(self, x_noisy, var_t, t, enabled):
# if enabled:
# return torch.sqrt(var_t) * x_noisy
# else:
# return None
#
#
# class DiffusionSubPowerVPSDE(DiffusionBase):
# """
# Diffusion implementation of the sub-power-VPSDE. This uses the same SDE like DiffusionSubVPSDE but with beta
# function that is a power function with user specified power. Note that for power=1, this reproduces the vanilla
# DiffusionSubVPSDE above.
# """
# def __init__(self, args):
# super().__init__(args)
# self.beta_start = args.beta_start
# self.beta_end = args.beta_end
# self.power = args.vpsde_power
#
# def f(self, t):
# return -0.5 * self.beta(t)
#
# def g2(self, t):
# return self.beta(t) * (1.0 - torch.exp(-2.0 * self.beta_start * t - 2.0 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)))
#
# def var(self, t):
# int_term = torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
# return torch.square(1.0 - int_term) + self.sigma2_0 * int_term
#
# def e2int_f(self, t):
# return torch.exp(-0.5 * self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#
# def beta(self, t):
# """ internal auxiliary beta function """
# return self.beta_start + (self.beta_end - self.beta_start) * t ** self.power
#
# def inv_var(self, var):
# raise NotImplementedError
#
# def mixing_component(self, x_noisy, var_t, t, enabled):
# if enabled:
# int_term = torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)).view(-1, 1, 1, 1)
# return torch.sqrt(var_t) * x_noisy / (torch.square(1.0 - int_term) + int_term)
# else:
# return None
#
# def var_vpsde(self, t):
# return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#
# def inv_var_vpsde(self, var):
# if self.power == 2:
# c = torch.log((1 - var) / (1 - self.sigma2_0))
# p = 3.0 * self.beta_start / (self.beta_end - self.beta_start)
# q = 3.0 * c / (self.beta_end - self.beta_start)
# a = -0.5 * q + torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0)
# b = -0.5 * q - torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0)
# return torch.pow(a, 1.0 / 3.0) + torch.pow(b, 1.0 / 3.0)
# else:
# raise NotImplementedError
#
#
# class DiffusionVESDE(DiffusionBase):
# """
# Diffusion implementation of the VESDE with dz = sqrt(beta(t)) * dW
# """
# def __init__(self, args):
# super().__init__(args)
# self.sigma2_min = args.sigma2_min
# self.sigma2_max = args.sigma2_max
# assert self.sigma2_min == self.sigma2_0, "VESDE was proposed implicitly assuming sigma2_min = sigma2_0!"
#
# def f(self, t):
# return torch.zeros_like(t, device='cuda')
#
# def g2(self, t):
# return self.sigma2_min * np.log(self.sigma2_max / self.sigma2_min) * ((self.sigma2_max / self.sigma2_min) ** t)
#
# def var(self, t):
# return self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t) - self.sigma2_min + self.sigma2_0
#
# def e2int_f(self, t):
# return torch.ones_like(t, device='cuda')
#
# def inv_var(self, var):
# return torch.log((var + self.sigma2_min - self.sigma2_0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min)
#
# def mixing_component(self, x_noisy, var_t, t, enabled):
# if enabled:
# return torch.sqrt(var_t) * x_noisy / (self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t.view(-1, 1, 1, 1)) - self.sigma2_min + 1.0)
# else:
# return None
#
# def var_N(self, t):
# return 1.0 - self.sigma2_min + self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t)
#
# def inv_var_N(self, var):
# return torch.log((var + self.sigma2_min - 1.0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min)
if __name__ == "__main__":
    class Foo:
        """Minimal stand-in for the argument namespace consumed by make_diffusion."""

        def __init__(self):
            self.sde_type = 'vpsde'
            self.sigma2_0 = 0.01
            self.sigma2_min = 3e-5
            self.sigma2_max = 0.999
            self.beta_start = 0.1
            self.beta_end = 20
            # DiffusionVPSDE.__init__ reads args.time_eps; without it the script
            # fails with AttributeError before any check runs.
            self.time_eps = 1e-2

    # A unit test to check the implementation of e2intf and var_t
    diff = make_diffusion(Foo())

    # round trip: inv_var(var(t)) should print (approximately) t = 0.5
    print(diff.inv_var(diff.var(torch.tensor(0.5))))

    # The finite-difference checks below used to sit behind a bare exit() and were
    # therefore unreachable. They stay disabled by default (same output as before);
    # note that with delta = 1e-8 the time grid holds roughly 1e8 float64 values.
    run_finite_difference_checks = False
    if run_finite_difference_checks:
        delta = 1e-8
        t = np.arange(start=0.001, stop=0.999, step=delta)
        t = torch.tensor(t)
        f_t = diff.f(t)
        e2intf = diff.e2int_f(t)

        # compute finite differences for e2intf: d/dt e2int_f(t) should equal f(t) * e2int_f(t)
        grad_fd = (e2intf[1:] - e2intf[:-1]) / delta
        grad_analytic = f_t[:-1] * e2intf[:-1]
        print(torch.max(torch.abs(grad_fd - grad_analytic)))

        # d/dt var(t) should equal 2 * f(t) * var(t) + g^2(t)
        var_t = diff.var(t)
        grad_fd = (var_t[1:] - var_t[:-1]) / delta
        grad_analytic = (2 * f_t * var_t + diff.g2(t))[:-1]
        print(torch.max(torch.abs(grad_fd - grad_analytic)))