# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

"""modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_continuous.py"""

from abc import ABC, abstractmethod
import numpy as np
import torch
import gc
# import utils.distributions as distributions
from utils.utils import trace_df_dx_hutchinson, sample_gaussian_like, sample_rademacher_like, get_mixed_prediction
from third_party.torchdiffeq.torchdiffeq import odeint
from torch.cuda.amp import autocast
from timeit import default_timer as timer
from loguru import logger


def make_diffusion(args):
    """ Simple diffusion factory function that returns diffusion instances. Only use this to create continuous diffusions. """
    # Note: in this file only DiffusionVPSDE is active; the other diffusion classes are commented out
    # below, so requesting any other sde_type will raise a NameError at construction time.
    if args.sde_type == 'geometric_sde':
        return DiffusionGeometric(args)
    elif args.sde_type == 'vpsde':
        return DiffusionVPSDE(args)
    elif args.sde_type == 'sub_vpsde':
        return DiffusionSubVPSDE(args)
    elif args.sde_type == 'power_vpsde':
        return DiffusionPowerVPSDE(args)
    elif args.sde_type == 'sub_power_vpsde':
        return DiffusionSubPowerVPSDE(args)
    elif args.sde_type == 'vesde':
        return DiffusionVESDE(args)
    else:
        raise ValueError("Unrecognized sde type: {}".format(args.sde_type))

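
# Illustrative sketch (not part of the original code): typical use of the factory above. The
# SimpleNamespace hyperparameters mirror the __main__ block at the bottom of this file and are
# assumptions, not values mandated by this module. Like the rest of the file, this assumes a CUDA
# device, because several schedule constants are created directly on 'cuda'.
def _example_make_vpsde_diffusion():
    from types import SimpleNamespace
    args = SimpleNamespace(sde_type='vpsde', sigma2_0=0.01,
                           beta_start=0.1, beta_end=20.0, time_eps=1e-2)
    diffusion = make_diffusion(args)
    t = torch.linspace(1e-2, 1.0, 5, device='cuda')
    # variance sigma_t^2 and mean coefficient mu_t of q(z_t | z_0) at a few times
    return diffusion, diffusion.var(t), diffusion.e2int_f(t)
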

class DiffusionBase(ABC):
    """
    Abstract base class for all diffusion implementations.
    """

    def __init__(self, args):
        super().__init__()
        self.sigma2_0 = args.sigma2_0
        self.sde_type = args.sde_type

    @abstractmethod
    def f(self, t):
        """ Returns the drift coefficient at time t: f(t). """
        pass

    @abstractmethod
    def g2(self, t):
        """ Returns the squared diffusion coefficient at time t: g^2(t). """
        pass

    @abstractmethod
    def var(self, t):
        r""" Returns the variance \sigma_t^2 at time t, where
        q(z_t|z_0) = N(z_t; \mu_t(z_0), \sigma_t^2 I).
        """
        pass

    @abstractmethod
    def e2int_f(self, t):
        r""" Returns e^{\int_0^t f(s) ds}, which corresponds to the mean coefficient at time t. """
        pass

    @abstractmethod
    def inv_var(self, var):
        """ Inverse of the variance function: returns the time t at which the variance equals var. """
        pass

    @abstractmethod
    def mixing_component(self, x_noisy, var_t, t, enabled):
        """ Returns the mixing component, which is the optimal denoising model assuming that q(z_0) is N(0, I). """
        pass

    def sample_q(self, x_init, noise, var_t, m_t):
        """ Returns a sample from the diffusion process at time t: z_t = m_t * x_init + sqrt(var_t) * noise. """
        return m_t * x_init + torch.sqrt(var_t) * noise

    def cross_entropy_const(self, ode_eps):
        """ Returns the per-dimension cross-entropy constant with variance according to the ODE integration cutoff ode_eps. """
        # _, c, h, w = x_init.shape
        return 0.5 * (1.0 + torch.log(2.0 * np.pi * self.var(t=torch.tensor(ode_eps, device='cuda'))))

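    # Illustrative sketch (not part of the original code): sample_q() implements the reparameterized draw
    # z_t = m_t * z_0 + sqrt(var_t) * eps from q(z_t | z_0). Assuming a diffusion built as in
    # _example_make_vpsde_diffusion() above, clean latents z_0 and a CUDA device (hypothetical names),
    # a forward-diffusion sample at time t could be drawn as:
    #
    #     t = 0.5 * torch.ones(batch_size, device='cuda')
    #     var_t = diffusion.var(t).view(-1, 1, 1, 1)
    #     m_t = diffusion.e2int_f(t).view(-1, 1, 1, 1)
    #     z_t = diffusion.sample_q(z_0, torch.randn_like(z_0), var_t, m_t)
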
    def compute_ode_nll(self, dae, eps, ode_eps, ode_solver_tol, enable_autocast=False,
                        no_autograd=False, num_samples=1, report_std=False,
                        condition_input=None, clip_feat=None):
        ## raise NotImplementedError
        """ Calculates the NLL based on the ODE framework, assuming integration cutoff ode_eps.
        Note: the log-probability and prior terms are currently commented out, so this method only
        integrates the probability flow ODE from ode_eps to 1.0 and returns the terminal state. """
        # ODE solver starts consuming the CPU memory without this on large models
        # https://github.com/scipy/scipy/issues/10070
        gc.collect()

        dae.eval()

        def ode_func(t, x):
            """ the ODE function (including log probability integration for NLL calculation) """
            global nfe_counter
            nfe_counter = nfe_counter + 1

            # x = state[0].detach()
            x = x.detach()
            x.requires_grad_(False)
            # noise = sample_gaussian_like(x)  # could also use rademacher noise (sample_rademacher_like)
            with torch.set_grad_enabled(False):
                with autocast(enabled=enable_autocast):
                    variance = self.var(t=t)
                    mixing_component = self.mixing_component(
                        x_noisy=x, var_t=variance, t=t, enabled=dae.mixed_prediction)
                    pred_params = dae(
                        x=x, t=t, condition_input=condition_input, clip_feat=clip_feat)
                    # Warning: here mixing_logit can be None
                    params = get_mixed_prediction(
                        dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component)
                    dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \
                        params / torch.sqrt(variance)
                    # dx_dt = - 0.5 * self.g2(t=t) * (x - params / torch.sqrt(variance))

                # with autocast(enabled=False):
                #     dlogp_x_dt = -trace_df_dx_hutchinson(dx_dt, x, noise, no_autograd).view(x.shape[0], 1)

            return dx_dt

        # NFE counter
        global nfe_counter

        nll_all, nfe_all = [], []
        for i in range(num_samples):
            # integrated log probability
            # logp_diff_t0 = torch.zeros(eps.shape[0], 1, device='cuda')

            nfe_counter = 0

            # solve the ODE
            x_t = odeint(
                ode_func,
                eps,
                torch.tensor([ode_eps, 1.0], device='cuda'),
                atol=ode_solver_tol,  # 1e-5
                rtol=ode_solver_tol,  # 1e-5
                # 'dopri5' or 'dopri8' methods also seem good.
                method="scipy_solver",
                options={"solver": 'RK45'},  # only for scipy solvers
            )
            # last output values
            x_t0 = x_t[-1]
            ## x_t0, logp_diff_t0 = x_t[-1], logp_diff_t[-1]

            # prior
            # if self.sde_type == 'vesde':
            #     logp_prior = torch.sum(distributions.log_p_var_normal(x_t0, var=self.sigma2_max), dim=[1, 2, 3])
            # else:
            #     logp_prior = torch.sum(distributions.log_p_standard_normal(x_t0), dim=[1, 2, 3])

            # log_likelihood = logp_prior - logp_diff_t0.view(-1)

            # nll_all.append(-log_likelihood)
            nfe_all.append(nfe_counter)
            print('nfe_counter: ', nfe_counter)

        # nfe_mean = np.mean(nfe_all)
        # nll_all = torch.stack(nll_all, dim=1)
        # nll_mean = torch.mean(nll_all, dim=1)
        # if num_samples > 1 and report_std:
        #     nll_stddev = torch.std(nll_all, dim=1)
        #     nll_stddev_batch = torch.mean(nll_stddev)
        #     nll_stderror_batch = nll_stddev_batch / np.sqrt(num_samples)
        # else:
        #     nll_stddev_batch = None
        #     nll_stderror_batch = None
        return x_t0  # nll_mean, nfe_mean, nll_stddev_batch, nll_stderror_batch

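    # Illustrative sketch (not part of the original code): in its current form compute_ode_nll() only pushes
    # a latent batch through the probability flow ODE from t=ode_eps to t=1 and returns the terminal state;
    # the commented-out pieces would additionally accumulate the log-density change for an NLL estimate.
    # Assuming a trained score model `dae` exposing mixed_prediction and mixing_logit, a call could look like:
    #
    #     eps = torch.randn(batch_size, *latent_shape, device='cuda')   # hypothetical latents
    #     x_T = diffusion.compute_ode_nll(dae, eps, ode_eps=1e-5, ode_solver_tol=1e-5)
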
    def sample_model_ode(self, dae, num_samples, shape, ode_eps,
                         ode_solver_tol, enable_autocast, temp, noise=None,
                         condition_input=None, mixing_logit=None,
                         use_cust_ode_func=0, init_t=1.0, return_all_sample=False, clip_feat=None
                         ):
        """ Generates samples using the ODE framework, assuming integration cutoff ode_eps. """
        # ODE solver starts consuming the CPU memory without this on large models
        # https://github.com/scipy/scipy/issues/10070
        gc.collect()

        dae.eval()

        def cust_ode_func(t, x):
            """ the ODE function (sampling only, no NLL stuff) """
            global nfe_counter
            nfe_counter = nfe_counter + 1
            if nfe_counter % 100 == 0:
                logger.info('nfe_counter={}', nfe_counter)
            with autocast(enabled=enable_autocast):
                variance = self.var(t=t)
                params = dae(x, x, t, condition_input=condition_input)
                dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \
                    params / torch.sqrt(variance)
                # dx_dt = - 0.5 * self.g2(t=t) * (x - params / torch.sqrt(variance))

            return dx_dt

        def ode_func(t, x):
            """ the ODE function (sampling only, no NLL stuff) """
            global nfe_counter
            nfe_counter = nfe_counter + 1
            if nfe_counter % 100 == 0:
                logger.info('nfe_counter={}', nfe_counter)
            with autocast(enabled=enable_autocast):
                variance = self.var(t=t)
                mixing_component = self.mixing_component(
                    x_noisy=x, var_t=variance, t=t, enabled=dae.mixed_prediction)
                pred_params = dae(
                    x=x, t=t, condition_input=condition_input, clip_feat=clip_feat)
                input_mixing_logit = mixing_logit if mixing_logit is not None else dae.mixing_logit
                params = get_mixed_prediction(
                    dae.mixed_prediction, pred_params, input_mixing_logit, mixing_component)
                dx_dt = self.f(t=t) * x + 0.5 * self.g2(t=t) * \
                    params / torch.sqrt(variance)
                # dx_dt = - 0.5 * self.g2(t=t) * (x - params / torch.sqrt(variance))

            return dx_dt

        # the initial noise
        if noise is None:
            noise = torch.randn(size=[num_samples] + shape, device='cuda')

        if self.sde_type == 'vesde':
            noise_init = temp * noise * np.sqrt(self.sigma2_max)
        else:
            noise_init = temp * noise

        # NFE counter
        global nfe_counter
        nfe_counter = 0

        # solve the ODE
        start = timer()
        samples_out = odeint(
            ode_func if not use_cust_ode_func else cust_ode_func,
            noise_init,
            torch.tensor([init_t, ode_eps], device='cuda'),
            atol=ode_solver_tol,  # 1e-5
            rtol=ode_solver_tol,  # 1e-5
            # 'dopri5' or 'dopri8' methods also seem good.
            method="scipy_solver",
            options={"solver": 'RK45'},  # only for scipy solvers
        )
        end = timer()
        ode_solve_time = end - start
        if return_all_sample:
            return samples_out[-1], samples_out, nfe_counter, ode_solve_time
        return samples_out[-1], nfe_counter, ode_solve_time

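    # Illustrative sketch (not part of the original code): a typical generation call, assuming a trained
    # score model `dae` and a CUDA device. `shape` is the per-sample latent shape (a list) and `temp`
    # scales the initial Gaussian noise; the values below are placeholders.
    #
    #     samples, nfe, solve_time = diffusion.sample_model_ode(
    #         dae, num_samples=16, shape=[3, 32, 32], ode_eps=1e-5,
    #         ode_solver_tol=1e-5, enable_autocast=False, temp=1.0)
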
    # def compute_dsm_nll(self, dae, eps, time_eps, enable_autocast, num_samples, report_std):
    #     with torch.no_grad():
    #         neg_log_prob_all = []
    #         for i in range(num_samples):
    #             assert self.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde'], "we don't support subVPSDE yet."
    #             t, var_t, m_t, obj_weight_t, _, _ = \
    #                 self.iw_quantities(eps.shape[0], time_eps, iw_sample_mode='ll_iw', iw_subvp_like_vp_sde=False)
    #
    #             noise = torch.randn(size=eps.size(), device='cuda')
    #             eps_t = self.sample_q(eps, noise, var_t, m_t)
    #             mixing_component = self.mixing_component(eps_t, var_t, t, enabled=dae.mixed_prediction)
    #             with autocast(enabled=enable_autocast):
    #                 pred_params = dae(eps_t, t)
    #                 params = get_mixed_prediction(dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component)
    #                 l2_term = torch.square(params - noise)
    #
    #             neg_log_prob_per_var = obj_weight_t * l2_term
    #             neg_log_prob_per_var += self.cross_entropy_const(time_eps)
    #             neg_log_prob = torch.sum(neg_log_prob_per_var, dim=[1, 2, 3])
    #
    #             neg_log_prob_all.append(neg_log_prob)
    #
    #         neg_log_prob_all = torch.stack(neg_log_prob_all, dim=1)
    #         nll = torch.mean(neg_log_prob_all, dim=1)
    #         if num_samples > 1 and report_std:
    #             nll_std = torch.std(neg_log_prob_all, dim=1)
    #             print('std nll:', nll_std)
    #
    #         return nll

    def iw_quantities(self, size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde):
        """ Returns t, var_t, m_t, obj_weight_t_p, obj_weight_t_q, g2_t for importance-weighted training. """
        # Note: only the VPSDE-like branch is available in this file; the sub-VPSDE and VESDE helpers
        # are commented out below.
        if self.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde']:
            return self._iw_quantities_vpsdelike(size, time_eps, iw_sample_mode)
        elif self.sde_type in ['sub_vpsde', 'sub_power_vpsde']:
            return self._iw_quantities_subvpsdelike(size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde)
        elif self.sde_type in ['vesde']:
            return self._iw_quantities_vesde(size, time_eps, iw_sample_mode)
        else:
            raise NotImplementedError

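    # Illustrative sketch (not part of the original code): typical use of iw_quantities() in a denoising
    # score matching training step, assuming a VPSDE diffusion, clean latents x0 and a score model `model`
    # on a CUDA device (all hypothetical names).
    #
    #     t, var_t, m_t, w_p, w_q, g2_t = diffusion.iw_quantities(
    #         batch_size, time_eps=1e-2, iw_sample_mode='ll_iw', iw_subvp_like_vp_sde=False)
    #     noise = torch.randn_like(x0)
    #     z_t = diffusion.sample_q(x0, noise, var_t, m_t)
    #     loss = torch.mean(w_q * torch.square(model(z_t, t) - noise))
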
    def debug_sheduler(self, time_eps):
        """ Debug helper: evaluates the schedule on a fixed grid of 1000 time points in (0, 1].
        Note: the time_eps argument is currently unused. """
        # torch.range is deprecated; torch.arange(1, 1001) yields the same 1000 points 1/1000, ..., 1.
        t = torch.arange(1, 1001) / 1000.0
        t = t.cuda()
        var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
        obj_weight_t_p = torch.ones(1, device='cuda')
        obj_weight_t_q = g2_t / (2.0 * var_t)
        return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), \
            obj_weight_t_p.view(-1, 1, 1, 1), \
            obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)

    def _iw_quantities_vpsdelike(self, size, time_eps, iw_sample_mode):
        """
        For all SDEs where the underlying SDE is of the form dz = -0.5 * beta(t) * z * dt + sqrt{beta(t)} * dw,
        as for the VPSDE.
        """
        rho = torch.rand(size=[size], device='cuda')

        if iw_sample_mode == 'll_uniform':
            # uniform t sampling - likelihood obj. for both q and p
            t = rho * (1. - time_eps) + time_eps
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t)

        elif iw_sample_mode == 'll_iw':
            # importance sampling for likelihood obj. - likelihood obj. for both q and p
            ones = torch.ones_like(rho, device='cuda')
            sigma2_1, sigma2_eps = self.var(ones), self.var(time_eps * ones)
            log_sigma2_1, log_sigma2_eps = torch.log(sigma2_1), torch.log(sigma2_eps)
            var_t = torch.exp(rho * log_sigma2_1 + (1 - rho) * log_sigma2_eps)
            t = self.inv_var(var_t)
            m_t, g2_t = self.e2int_f(t), self.g2(t)
            obj_weight_t_p = obj_weight_t_q = 0.5 * \
                (log_sigma2_1 - log_sigma2_eps) / (1.0 - var_t)

        elif iw_sample_mode == 'drop_all_uniform':
            # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
            t = rho * (1. - time_eps) + time_eps
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = torch.ones(1, device='cuda')
            obj_weight_t_q = g2_t / (2.0 * var_t)

        elif iw_sample_mode == 'drop_all_iw':
            # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
            assert self.sde_type == 'vpsde', 'Importance sampling for fully unweighted objective is currently only ' \
                                             'implemented for the regular VPSDE.'
            t = torch.sqrt(1.0 / self.delta_beta_half) * torch.erfinv(rho * self.const_norm_2 + self.const_erf) - self.beta_frac
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = self.const_norm / (1.0 - var_t)
            obj_weight_t_q = obj_weight_t_p * g2_t / (2.0 * var_t)

        elif iw_sample_mode == 'drop_sigma2t_iw':
            # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
            ones = torch.ones_like(rho, device='cuda')
            sigma2_1, sigma2_eps = self.var(ones), self.var(time_eps * ones)
            var_t = rho * sigma2_1 + (1 - rho) * sigma2_eps
            t = self.inv_var(var_t)
            m_t, g2_t = self.e2int_f(t), self.g2(t)
            obj_weight_t_p = 0.5 * (sigma2_1 - sigma2_eps) / (1.0 - var_t)
            obj_weight_t_q = obj_weight_t_p / var_t

        elif iw_sample_mode == 'drop_sigma2t_uniform':
            # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
            t = rho * (1. - time_eps) + time_eps
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = g2_t / 2.0
            obj_weight_t_q = g2_t / (2.0 * var_t)

        elif iw_sample_mode == 'rescale_iw':
            # importance sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p
            t = rho * (1. - time_eps) + time_eps
            var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
            obj_weight_t_p = 0.5 / (1.0 - var_t)
            obj_weight_t_q = g2_t / (2.0 * var_t)

        else:
            raise ValueError(
                "Unrecognized importance sampling type: {}".format(iw_sample_mode))

        return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \
            obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)

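    # Illustrative note (not part of the original code): in the 'll_iw' branch above, t is drawn so that
    # log var(t) is uniform on [log var(time_eps), log var(1)], i.e. the proposal density over t is
    # proportional to d log var(t) / dt. For this SDE family g2(t) = var'(t) / (1 - var(t)), so dividing
    # the uniform-time likelihood weight g2(t) / (2 * var(t)) by that proposal density gives the importance
    # weight 0.5 * (log sigma2_1 - log sigma2_eps) / (1 - var_t) that is returned above.
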
    # def _iw_quantities_subvpsdelike(self, size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde):
    #     """
    #     For all SDEs where the underlying SDE is of the form
    #     dz = -0.5 * beta(t) * z * dt + sqrt{beta(t) * (1 - exp[-2 * betaintegral])} * dw, like for the Sub-VPSDE.
    #     When iw_subvp_like_vp_sde is True, then we define the importance sampling distributions based on an analogous
    #     VPSDE, while still using the Sub-VPSDE. The motivation is that deriving the correct importance sampling
    #     distributions for the Sub-VPSDE itself is hard, but the importance sampling distributions from analogous VPSDEs
    #     probably already significantly reduce the variance also for the Sub-VPSDE.
    #     """
    #     rho = torch.rand(size=[size], device='cuda')
    #
    #     if iw_sample_mode == 'll_uniform':
    #         # uniform t sampling - likelihood obj. for both q and p
    #         t = rho * (1. - time_eps) + time_eps
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t)
    #
    #     elif iw_sample_mode == 'll_iw':
    #         if iw_subvp_like_vp_sde:
    #             # importance sampling for vpsde likelihood obj. - sub-vpsde likelihood obj. for both q and p
    #             ones = torch.ones_like(rho, device='cuda')
    #             sigma2_1, sigma2_eps = self.var_vpsde(ones), self.var_vpsde(time_eps * ones)
    #             log_sigma2_1, log_sigma2_eps = torch.log(sigma2_1), torch.log(sigma2_eps)
    #             var_t_vpsde = torch.exp(rho * log_sigma2_1 + (1 - rho) * log_sigma2_eps)
    #             t = self.inv_var_vpsde(var_t_vpsde)
    #             var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #             obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t) * \
    #                 (log_sigma2_1 - log_sigma2_eps) * var_t_vpsde / (1 - var_t_vpsde) / self.beta(t)
    #         else:
    #             raise NotImplementedError
    #
    #     elif iw_sample_mode == 'drop_all_uniform':
    #         # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
    #         t = rho * (1. - time_eps) + time_eps
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = torch.ones(1, device='cuda')
    #         obj_weight_t_q = g2_t / (2.0 * var_t)
    #
    #     elif iw_sample_mode == 'drop_all_iw':
    #         if iw_subvp_like_vp_sde:
    #             # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
    #             assert self.sde_type == 'sub_vpsde', 'Importance sampling for fully unweighted objective is ' \
    #                                                  'currently only implemented for the Sub-VPSDE.'
    #             t = torch.sqrt(1.0 / self.delta_beta_half) * torch.erfinv(rho * self.const_norm_2 + self.const_erf) - self.beta_frac
    #             var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #             obj_weight_t_p = self.const_norm / (1.0 - self.var_vpsde(t))
    #             obj_weight_t_q = obj_weight_t_p * g2_t / (2.0 * var_t)
    #         else:
    #             raise NotImplementedError
    #
    #     elif iw_sample_mode == 'drop_sigma2t_iw':
    #         if iw_subvp_like_vp_sde:
    #             # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
    #             ones = torch.ones_like(rho, device='cuda')
    #             sigma2_1, sigma2_eps = self.var_vpsde(ones), self.var_vpsde(time_eps * ones)
    #             var_t_vpsde = rho * sigma2_1 + (1 - rho) * sigma2_eps
    #             t = self.inv_var_vpsde(var_t_vpsde)
    #             var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #             obj_weight_t_p = 0.5 * g2_t / self.beta(t) * (sigma2_1 - sigma2_eps) / (1.0 - var_t_vpsde)
    #             obj_weight_t_q = obj_weight_t_p / var_t
    #         else:
    #             raise NotImplementedError
    #
    #     elif iw_sample_mode == 'drop_sigma2t_uniform':
    #         # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
    #         t = rho * (1. - time_eps) + time_eps
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = g2_t / 2.0
    #         obj_weight_t_q = g2_t / (2.0 * var_t)
    #
    #     elif iw_sample_mode == 'rescale_iw':
    #         # importance sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p
    #         # Note that we use the sub-vpsde variance to scale the p objective! It's not clear what's optimal here!
    #         t = rho * (1. - time_eps) + time_eps
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = 0.5 / (1.0 - var_t)
    #         obj_weight_t_q = g2_t / (2.0 * var_t)
    #
    #     else:
    #         raise ValueError("Unrecognized importance sampling type: {}".format(iw_sample_mode))
    #
    #     return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \
    #         obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)

    # def _iw_quantities_vesde(self, size, time_eps, iw_sample_mode):
    #     """
    #     For the VESDE.
    #     """
    #     rho = torch.rand(size=[size], device='cuda')
    #
    #     if iw_sample_mode == 'll_uniform':
    #         # uniform t sampling - likelihood obj. for both q and p
    #         t = rho * (1. - time_eps) + time_eps
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = obj_weight_t_q = g2_t / (2.0 * var_t)
    #
    #     elif iw_sample_mode == 'll_iw':
    #         # importance sampling for likelihood obj. - likelihood obj. for both q and p
    #         ones = torch.ones_like(rho, device='cuda')
    #         nsigma2_1, nsigma2_eps, sigma2_eps = self.var_N(ones), self.var_N(time_eps * ones), self.var(time_eps * ones)
    #         log_frac_sigma2_1, log_frac_sigma2_eps = torch.log(self.sigma2_max / nsigma2_1), torch.log(nsigma2_eps / sigma2_eps)
    #         var_N_t = (1.0 - self.sigma2_min) / (1.0 - torch.exp(rho * (log_frac_sigma2_1 + log_frac_sigma2_eps) - log_frac_sigma2_eps))
    #         t = self.inv_var_N(var_N_t)
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = obj_weight_t_q = 0.5 * (log_frac_sigma2_1 + log_frac_sigma2_eps) * self.var_N(t) / (1.0 - self.sigma2_min)
    #
    #     elif iw_sample_mode == 'drop_all_uniform':
    #         # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
    #         t = rho * (1. - time_eps) + time_eps
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = torch.ones(1, device='cuda')
    #         obj_weight_t_q = g2_t / (2.0 * var_t)
    #
    #     elif iw_sample_mode == 'drop_all_iw':
    #         # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
    #         ones = torch.ones_like(rho, device='cuda')
    #         nsigma2_1, nsigma2_eps, sigma2_eps = self.var_N(ones), self.var_N(time_eps * ones), self.var(time_eps * ones)
    #         log_frac_sigma2_1, log_frac_sigma2_eps = torch.log(self.sigma2_max / nsigma2_1), torch.log(nsigma2_eps / sigma2_eps)
    #         var_N_t = (1.0 - self.sigma2_min) / (1.0 - torch.exp(rho * (log_frac_sigma2_1 + log_frac_sigma2_eps) - log_frac_sigma2_eps))
    #         t = self.inv_var_N(var_N_t)
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_q = 0.5 * (log_frac_sigma2_1 + log_frac_sigma2_eps) * self.var_N(t) / (1.0 - self.sigma2_min)
    #         obj_weight_t_p = 2.0 * obj_weight_t_q / np.log(self.sigma2_max / self.sigma2_min)
    #
    #     elif iw_sample_mode == 'drop_sigma2t_iw':
    #         # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
    #         ones = torch.ones_like(rho, device='cuda')
    #         nsigma2_1, nsigma2_eps = self.var_N(ones), self.var_N(time_eps * ones)
    #         var_N_t = torch.exp(rho * torch.log(nsigma2_1) + (1 - rho) * torch.log(nsigma2_eps))
    #         t = self.inv_var_N(var_N_t)
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = 0.5 * torch.log(nsigma2_1 / nsigma2_eps) * self.var_N(t)
    #         obj_weight_t_q = obj_weight_t_p / var_t
    #
    #     elif iw_sample_mode == 'drop_sigma2t_uniform':
    #         # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
    #         t = rho * (1. - time_eps) + time_eps
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = g2_t / 2.0
    #         obj_weight_t_q = g2_t / (2.0 * var_t)
    #
    #     elif iw_sample_mode == 'rescale_iw':
    #         # uniform sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p
    #         t = rho * (1. - time_eps) + time_eps
    #         var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
    #         obj_weight_t_p = 0.5 / (1.0 - var_t)
    #         obj_weight_t_q = g2_t / (2.0 * var_t)
    #
    #     else:
    #         raise ValueError("Unrecognized importance sampling type: {}".format(iw_sample_mode))
    #
    #     return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t_p.view(-1, 1, 1, 1), \
    #         obj_weight_t_q.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)


# class DiffusionGeometric(DiffusionBase):
#     """
#     Diffusion implementation with dz = -0.5 * beta(t) * z * dt + sqrt(beta(t)) * dW SDE and geometric progression of
#     variance. This is our new diffusion.
#     """
#     def __init__(self, args):
#         super().__init__(args)
#         self.sigma2_min = args.sigma2_min
#         self.sigma2_max = args.sigma2_max
#
#     def f(self, t):
#         return -0.5 * self.g2(t)
#
#     def g2(self, t):
#         sigma2_geom = self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t)
#         log_term = np.log(self.sigma2_max / self.sigma2_min)
#         return sigma2_geom * log_term / (1.0 - self.sigma2_0 + self.sigma2_min - sigma2_geom)
#
#     def var(self, t):
#         return self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t) - self.sigma2_min + self.sigma2_0
#
#     def e2int_f(self, t):
#         return torch.sqrt(1.0 + self.sigma2_min * (1.0 - (self.sigma2_max / self.sigma2_min) ** t) / (1.0 - self.sigma2_0))
#
#     def inv_var(self, var):
#         return torch.log((var + self.sigma2_min - self.sigma2_0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min)
#
#     def mixing_component(self, x_noisy, var_t, t, enabled):
#         if enabled:
#             return torch.sqrt(var_t) * x_noisy
#         else:
#             return None
#


class DiffusionVPSDE(DiffusionBase):
    """
    Diffusion implementation of the VPSDE. This uses the same SDE as DiffusionGeometric but with linear beta(t).
    Note that we need to scale beta_start and beta_end by 1000 relative to JH's DDPM values, since our t is in [0, 1].
    """

    def __init__(self, args):
        super().__init__(args)
        self.beta_start = args.beta_start
        self.beta_end = args.beta_end
        logger.info('VPSDE: beta_start={}, beta_end={}, sigma2_0={}',
                    self.beta_start, self.beta_end, self.sigma2_0)
        # auxiliary constants (yes, this is not super clean...)
        self.time_eps = args.time_eps
        self.delta_beta_half = torch.tensor(
            0.5 * (self.beta_end - self.beta_start), device='cuda')
        self.beta_frac = torch.tensor(
            self.beta_start / (self.beta_end - self.beta_start), device='cuda')
        self.const_aq = (1.0 - self.sigma2_0) * torch.exp(0.5 * self.beta_frac) * \
            torch.sqrt(0.25 * np.pi / self.delta_beta_half)
        self.const_erf = torch.erf(torch.sqrt(
            self.delta_beta_half) * (self.time_eps + self.beta_frac))
        self.const_norm = self.const_aq * \
            (torch.erf(torch.sqrt(self.delta_beta_half)
                       * (1.0 + self.beta_frac)) - self.const_erf)
        self.const_norm_2 = torch.erf(torch.sqrt(
            self.delta_beta_half) * (1.0 + self.beta_frac)) - self.const_erf

    def f(self, t):
        return -0.5 * self.g2(t)

    def g2(self, t):
        return self.beta_start + (self.beta_end - self.beta_start) * t

    def var(self, t):
        return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t)

    def e2int_f(self, t):
        return torch.exp(-0.5 * self.beta_start * t - 0.25 * (self.beta_end - self.beta_start) * t * t)

    def inv_var(self, var):
        c = torch.log((1 - var) / (1 - self.sigma2_0))
        a = self.beta_end - self.beta_start
        t = (-self.beta_start + torch.sqrt(np.square(self.beta_start) - 2 * a * c)) / a
        return t

    def mixing_component(self, x_noisy, var_t, t, enabled):
        if enabled:
            return torch.sqrt(var_t) * x_noisy
        else:
            return None

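
# Illustrative sketch (not part of the original code): as noted in the DiffusionVPSDE docstring, the
# continuous-time beta(t) = beta_start + (beta_end - beta_start) * t with t in [0, 1] is the DDPM
# schedule with its per-step betas scaled up by the number of steps. The rough check below (assumed
# values: 1000 steps, DDPM betas from 1e-4 to 2e-2) recovers the discrete betas up to a small
# discretization offset on the order of 1e-5.
def _example_vpsde_vs_ddpm_betas(beta_start=0.1, beta_end=20.0, num_steps=1000):
    t = (torch.arange(num_steps, dtype=torch.float64) + 0.5) / num_steps  # step midpoints in [0, 1]
    beta_per_step = (beta_start + (beta_end - beta_start) * t) / num_steps  # integral of beta(t) over each step
    beta_ddpm = torch.linspace(1e-4, 2e-2, num_steps, dtype=torch.float64)  # common discrete DDPM schedule
    return torch.max(torch.abs(beta_per_step - beta_ddpm))
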

# class DiffusionSubVPSDE(DiffusionBase):
#     """
#     Diffusion implementation of the sub-VPSDE. Note that this uses a different SDE compared to the above two diffusions.
#     """
#     def __init__(self, args):
#         super().__init__(args)
#         self.beta_start = args.beta_start
#         self.beta_end = args.beta_end
#
#         # auxiliary constants (assumes regular VPSDE... yes, this is not super clean...)
#         self.time_eps = args.time_eps
#         self.delta_beta_half = torch.tensor(0.5 * (self.beta_end - self.beta_start), device='cuda')
#         self.beta_frac = torch.tensor(self.beta_start / (self.beta_end - self.beta_start), device='cuda')
#         self.const_aq = (1.0 - self.sigma2_0) * torch.exp(0.5 * self.beta_frac) * torch.sqrt(0.25 * np.pi / self.delta_beta_half)
#         self.const_erf = torch.erf(torch.sqrt(self.delta_beta_half) * (self.time_eps + self.beta_frac))
#         self.const_norm = self.const_aq * (torch.erf(torch.sqrt(self.delta_beta_half) * (1.0 + self.beta_frac)) - self.const_erf)
#         self.const_norm_2 = torch.erf(torch.sqrt(self.delta_beta_half) * (1.0 + self.beta_frac)) - self.const_erf
#
#     def f(self, t):
#         return -0.5 * self.beta(t)
#
#     def g2(self, t):
#         return self.beta(t) * (1.0 - torch.exp(-2.0 * self.beta_start * t - (self.beta_end - self.beta_start) * t * t))
#
#     def var(self, t):
#         int_term = torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t)
#         return torch.square(1.0 - int_term) + self.sigma2_0 * int_term
#
#     def e2int_f(self, t):
#         return torch.exp(-0.5 * self.beta_start * t - 0.25 * (self.beta_end - self.beta_start) * t * t)
#
#     def beta(self, t):
#         """ auxiliary beta function """
#         return self.beta_start + (self.beta_end - self.beta_start) * t
#
#     def inv_var(self, var):
#         raise NotImplementedError
#
#     def mixing_component(self, x_noisy, var_t, t, enabled):
#         if enabled:
#             int_term = torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t).view(-1, 1, 1, 1)
#             return torch.sqrt(var_t) * x_noisy / (torch.square(1.0 - int_term) + int_term)
#         else:
#             return None
#
#     def var_vpsde(self, t):
#         return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t * t)
#
#     def inv_var_vpsde(self, var):
#         c = torch.log((1 - var) / (1 - self.sigma2_0))
#         a = self.beta_end - self.beta_start
#         t = (-self.beta_start + torch.sqrt(np.square(self.beta_start) - 2 * a * c)) / a
#         return t
#
#
# class DiffusionPowerVPSDE(DiffusionBase):
#     """
#     Diffusion implementation of the power-VPSDE. This uses the same SDE as DiffusionGeometric but with a beta function
#     that is a power function with user specified power. Note that for power=1, this reproduces the vanilla
#     DiffusionVPSDE above.
#     """
#     def __init__(self, args):
#         super().__init__(args)
#         self.beta_start = args.beta_start
#         self.beta_end = args.beta_end
#         self.power = args.vpsde_power
#
#     def f(self, t):
#         return -0.5 * self.g2(t)
#
#     def g2(self, t):
#         return self.beta_start + (self.beta_end - self.beta_start) * t ** self.power
#
#     def var(self, t):
#         return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#
#     def e2int_f(self, t):
#         return torch.exp(-0.5 * self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#
#     def inv_var(self, var):
#         if self.power == 2:
#             c = torch.log((1 - var) / (1 - self.sigma2_0))
#             p = 3.0 * self.beta_start / (self.beta_end - self.beta_start)
#             q = 3.0 * c / (self.beta_end - self.beta_start)
#             a = -0.5 * q + torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0)
#             b = -0.5 * q - torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0)
#             return torch.pow(a, 1.0 / 3.0) + torch.pow(b, 1.0 / 3.0)
#         else:
#             raise NotImplementedError
#
#     def mixing_component(self, x_noisy, var_t, t, enabled):
#         if enabled:
#             return torch.sqrt(var_t) * x_noisy
#         else:
#             return None
#
#
# class DiffusionSubPowerVPSDE(DiffusionBase):
#     """
#     Diffusion implementation of the sub-power-VPSDE. This uses the same SDE as DiffusionSubVPSDE but with a beta
#     function that is a power function with user specified power. Note that for power=1, this reproduces the vanilla
#     DiffusionSubVPSDE above.
#     """
#     def __init__(self, args):
#         super().__init__(args)
#         self.beta_start = args.beta_start
#         self.beta_end = args.beta_end
#         self.power = args.vpsde_power
#
#     def f(self, t):
#         return -0.5 * self.beta(t)
#
#     def g2(self, t):
#         return self.beta(t) * (1.0 - torch.exp(-2.0 * self.beta_start * t - 2.0 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)))
#
#     def var(self, t):
#         int_term = torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#         return torch.square(1.0 - int_term) + self.sigma2_0 * int_term
#
#     def e2int_f(self, t):
#         return torch.exp(-0.5 * self.beta_start * t - 0.5 * (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#
#     def beta(self, t):
#         """ internal auxiliary beta function """
#         return self.beta_start + (self.beta_end - self.beta_start) * t ** self.power
#
#     def inv_var(self, var):
#         raise NotImplementedError
#
#     def mixing_component(self, x_noisy, var_t, t, enabled):
#         if enabled:
#             int_term = torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0)).view(-1, 1, 1, 1)
#             return torch.sqrt(var_t) * x_noisy / (torch.square(1.0 - int_term) + int_term)
#         else:
#             return None
#
#     def var_vpsde(self, t):
#         return 1.0 - (1.0 - self.sigma2_0) * torch.exp(-self.beta_start * t - (self.beta_end - self.beta_start) * t ** (self.power + 1) / (self.power + 1.0))
#
#     def inv_var_vpsde(self, var):
#         if self.power == 2:
#             c = torch.log((1 - var) / (1 - self.sigma2_0))
#             p = 3.0 * self.beta_start / (self.beta_end - self.beta_start)
#             q = 3.0 * c / (self.beta_end - self.beta_start)
#             a = -0.5 * q + torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0)
#             b = -0.5 * q - torch.sqrt(q ** 2 / 4.0 + p ** 3 / 27.0)
#             return torch.pow(a, 1.0 / 3.0) + torch.pow(b, 1.0 / 3.0)
#         else:
#             raise NotImplementedError
#
#
# class DiffusionVESDE(DiffusionBase):
#     """
#     Diffusion implementation of the VESDE with dz = sqrt(beta(t)) * dW
#     """
#     def __init__(self, args):
#         super().__init__(args)
#         self.sigma2_min = args.sigma2_min
#         self.sigma2_max = args.sigma2_max
#         assert self.sigma2_min == self.sigma2_0, "VESDE was proposed implicitly assuming sigma2_min = sigma2_0!"
#
#     def f(self, t):
#         return torch.zeros_like(t, device='cuda')
#
#     def g2(self, t):
#         return self.sigma2_min * np.log(self.sigma2_max / self.sigma2_min) * ((self.sigma2_max / self.sigma2_min) ** t)
#
#     def var(self, t):
#         return self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t) - self.sigma2_min + self.sigma2_0
#
#     def e2int_f(self, t):
#         return torch.ones_like(t, device='cuda')
#
#     def inv_var(self, var):
#         return torch.log((var + self.sigma2_min - self.sigma2_0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min)
#
#     def mixing_component(self, x_noisy, var_t, t, enabled):
#         if enabled:
#             return torch.sqrt(var_t) * x_noisy / (self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t.view(-1, 1, 1, 1)) - self.sigma2_min + 1.0)
#         else:
#             return None
#
#     def var_N(self, t):
#         return 1.0 - self.sigma2_min + self.sigma2_min * ((self.sigma2_max / self.sigma2_min) ** t)
#
#     def inv_var_N(self, var):
#         return torch.log((var + self.sigma2_min - 1.0) / self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min)


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    class Foo:
        def __init__(self):
            self.sde_type = 'vpsde'
            self.sigma2_0 = 0.01
            self.sigma2_min = 3e-5
            self.sigma2_max = 0.999
            self.beta_start = 0.1
            self.beta_end = 20
            self.time_eps = 1e-2  # required by DiffusionVPSDE.__init__; the value here is assumed

    # A unit test to check the implementation of e2int_f and var
    diff = make_diffusion(Foo())

    # round-trip check: inv_var should invert var
    print(diff.inv_var(diff.var(torch.tensor(0.5))))
    exit()  # note: this early exit skips the finite-difference checks below

    delta = 1e-8
    t = np.arange(start=0.001, stop=0.999, step=delta)
    t = torch.tensor(t)

    f_t = diff.f(t)
    e2intf = diff.e2int_f(t)
    # compare finite differences of e2int_f against the analytic derivative d/dt e2int_f = f(t) * e2int_f
    grad_fd = (e2intf[1:] - e2intf[:-1]) / delta
    grad_analytic = f_t[:-1] * e2intf[:-1]
    print(torch.max(torch.abs(grad_fd - grad_analytic)))

    # compare finite differences of var against the analytic derivative d/dt var = 2 * f(t) * var + g2(t)
    var_t = diff.var(t)
    grad_fd = (var_t[1:] - var_t[:-1]) / delta
    grad_analytic = (2 * f_t * var_t + diff.g2(t))[:-1]
    print(torch.max(torch.abs(grad_fd - grad_analytic)))