# LION/utils/diffusion.py
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""
copied and modified from source:
https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_discretized.py
"""
import time

import numpy as np
import torch
from loguru import logger
from torch.nn import Module


def extract(input, t, shape):
    """Gather the schedule value at timestep t for each batch element.

    Returns input[t] reshaped to (shape[0], 1, ..., 1) so that it
    broadcasts against a batch of samples with layout `shape`.
    """
    B = t.shape[0]
    out = torch.gather(input, 0, t.to(input.device))
    reshape = [B] + [1] * (len(shape) - 1)
    out = out.reshape(*reshape)
    return out
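

# `_warmup_beta` is referenced by the 'warmup10'/'warmup50' branches below but
# is not defined in this file. A minimal sketch following the warmup schedule
# of the DDPM reference implementation (the same shape the 'cust' branch
# builds by hand): hold beta at `end`, with a linear ramp from `start` to
# `end` over the first `warmup_frac` fraction of the steps.
def _warmup_beta(start, end, n_timestep, warmup_frac):
    betas = end * np.ones(n_timestep, dtype=np.float64)
    warmup_time = int(n_timestep * warmup_frac)
    betas[:warmup_time] = np.linspace(
        start, end, warmup_time, dtype=np.float64)
    return torch.from_numpy(betas)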


def make_beta_schedule(schedule, start, end, n_timestep):
    if schedule == 'cust':  # airplane: linear warmup over the first 10% of steps
        betas = end * np.ones(n_timestep, dtype=np.float64)
        warmup_time = int(n_timestep * 0.1)
        betas[:warmup_time] = np.linspace(
            start, end, warmup_time, dtype=np.float64)
        betas = torch.from_numpy(betas)
    elif schedule == 'quad':
        betas = torch.linspace(start**0.5, end**0.5, n_timestep,
                               dtype=torch.float64)**2
    elif schedule == 'linear':
        betas = torch.linspace(start, end, n_timestep, dtype=torch.float64)
    elif schedule == 'warmup10':
        betas = _warmup_beta(start, end, n_timestep, 0.1)
    elif schedule == 'warmup50':
        betas = _warmup_beta(start, end, n_timestep, 0.5)
    elif schedule == 'const':
        betas = end * torch.ones(n_timestep, dtype=torch.float64)
    elif schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / torch.linspace(
            n_timestep, 1, n_timestep, dtype=torch.float64)
    else:
        raise NotImplementedError(schedule)
    return betas
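
# Example (illustrative values): the standard 1000-step linear schedule from
# the DDPM paper; `betas` is a float64 tensor of shape (1000,).
#   betas = make_beta_schedule('linear', 1e-4, 2e-2, 1000)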


class VarianceSchedule(Module):
    def __init__(self, num_steps, beta_1, beta_T, mode='linear'):
        super().__init__()
        assert mode in ('linear', 'cust')
        self.num_steps = num_steps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.mode = mode
        beta_start = self.beta_1
        beta_end = self.beta_T
        assert beta_start <= beta_end, 'require beta_start <= beta_end'
        logger.info('use beta: {} - {}', beta_1, beta_T)
        tic = time.time()
        betas = make_beta_schedule(mode, beta_start, beta_end, num_steps)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, 0)
        alphas_cumprod_prev = torch.cat(
            (torch.tensor([1], dtype=torch.float64), alphas_cumprod[:-1]), 0)
        # Variance of the forward-process posterior q(x_{t-1} | x_t, x_0):
        #   beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        posterior_variance = betas * (1 - alphas_cumprod_prev) / (
            1 - alphas_cumprod)
self.register("betas", betas)
self.register("alphas_cumprod", alphas_cumprod)
self.register("alphas_cumprod_prev", alphas_cumprod_prev)
self.register("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
self.register("sqrt_one_minus_alphas_cumprod",
torch.sqrt(1 - alphas_cumprod))
self.register("log_one_minus_alphas_cumprod",
torch.log(1 - alphas_cumprod))
self.register("sqrt_recip_alphas_cumprod", torch.rsqrt(alphas_cumprod))
self.register("sqrt_recipm1_alphas_cumprod",
torch.sqrt(1 / alphas_cumprod - 1))
self.register("posterior_variance", posterior_variance)
if len(posterior_variance) > 1:
self.register("posterior_log_variance_clipped",
torch.log(
torch.cat((posterior_variance[1].view(
1, 1), posterior_variance[1:].view(-1, 1)), 0)).view(-1)
)
else:
self.register("posterior_log_variance_clipped",
torch.log(posterior_variance[0].view(-1)))
self.register("posterior_mean_coef1",
(betas * torch.sqrt(alphas_cumprod_prev) /
(1 - alphas_cumprod)))
self.register("posterior_mean_coef2",
((1 - alphas_cumprod_prev) * torch.sqrt(alphas) /
(1 - alphas_cumprod)))
logger.info('built beta schedule: t={:.2f}s', time.time() - tic)

    def register(self, name, tensor):
        # Store schedule constants as float32 buffers so they follow the
        # module across devices and appear in its state_dict.
        self.register_buffer(name, tensor.type(torch.float32))

    def all_sample_t(self):
        # Return a subsampled list of timesteps: every 50th step when the
        # schedule is long, every step otherwise.
        step = 50 if self.num_steps > 20 else 1
        ts = np.arange(0, self.num_steps, step)
        return ts.tolist()

    def get_sigmas(self, t, flexibility):
        # Interpolate between sqrt(beta_t) (flexibility=1) and the posterior
        # standard deviation (flexibility=0).
        assert 0 <= flexibility <= 1
        sigmas = self.sigmas_flex[t] * flexibility + \
            self.sigmas_inflex[t] * (1 - flexibility)
        return sigmas
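

# Minimal usage sketch (illustrative values): build a schedule and gather
# broadcastable per-step coefficients for the forward process
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
#
#   sched = VarianceSchedule(num_steps=1000, beta_1=1e-4, beta_T=2e-2)
#   x0 = torch.randn(16, 2048, 3)          # e.g. a batch of point clouds
#   t = torch.randint(0, sched.num_steps, (16,))
#   a = extract(sched.sqrt_alphas_cumprod, t, x0.shape)
#   s = extract(sched.sqrt_one_minus_alphas_cumprod, t, x0.shape)
#   x_t = a * x0 + s * torch.randn_like(x0)
#   sigma = sched.get_sigmas(t, flexibility=0.0)   # posterior std at each t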