171 lines
7.1 KiB
Python
171 lines
7.1 KiB
Python
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
#
|
||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||
# and proprietary rights in and to this software, related documentation
|
||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||
# distribution of this software and related documentation without an express
|
||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||
"""
|
||
copied and modified from source:
|
||
https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_discretized.py
|
||
"""
|
||
from loguru import logger
|
||
import time
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from torch.nn import Module, Parameter, ModuleList
|
||
import numpy as np
|
||
|
||
|
||
def extract(input, t, shape):
    """Gather per-timestep values and reshape them for broadcasting.

    Args:
        input: tensor of per-timestep values (e.g. a schedule buffer); assumed
            1-D with the same rank as ``t`` -- TODO confirm for all callers.
        t: integer tensor of timestep indices, one per batch element.
        shape: shape of the tensor the result will be broadcast against; only
            its rank (``len(shape)``) is used.

    Returns:
        Tensor of shape ``(t.shape[0], 1, ..., 1)`` with ``len(shape)`` dims,
        where entry ``b`` is ``input[t[b]]``.
    """
    batch_size = t.shape[0]
    # torch.gather requires an int64 index living on the same device as input.
    index = t.to(device=input.device, dtype=torch.long)
    out = torch.gather(input, 0, index)
    # Trailing singleton dims let the result broadcast over the data tensor.
    return out.reshape(batch_size, *([1] * (len(shape) - 1)))
|
||
|
||
|
||
def _warmup_beta(start, end, n_timestep, warmup_frac):
    """Return a float64 schedule ramping linearly from ``start`` to ``end``
    over the first ``int(n_timestep * warmup_frac)`` steps, then constant
    at ``end``."""
    betas = end * np.ones(n_timestep, dtype=np.float64)
    warmup_time = int(n_timestep * warmup_frac)
    betas[:warmup_time] = np.linspace(
        start, end, warmup_time, dtype=np.float64)
    return torch.from_numpy(betas)


def make_beta_schedule(schedule, start, end, n_timestep):
    """Build a 1-D float64 tensor of ``n_timestep`` diffusion betas.

    Args:
        schedule: one of ``'cust'``, ``'quad'``, ``'linear'``, ``'warmup10'``,
            ``'warmup50'``, ``'const'``, ``'jsd'``.
        start: first beta (ignored by ``'const'`` and ``'jsd'``).
        end: last / plateau beta (ignored by ``'jsd'``).
        n_timestep: number of diffusion steps.

    Returns:
        ``torch.float64`` tensor of shape ``(n_timestep,)``.

    Raises:
        NotImplementedError: for an unknown ``schedule`` name.
    """
    if schedule in ("cust", "warmup10"):
        # "cust" (annotated "airplane" in the original code) is a linear
        # warmup over the first 10% of steps followed by a constant ``end``.
        betas = _warmup_beta(start, end, n_timestep, 0.1)
    elif schedule == "quad":
        betas = torch.linspace(start**0.5, end**0.5, n_timestep,
                               dtype=torch.float64)**2
    elif schedule == "linear":
        betas = torch.linspace(start, end, n_timestep, dtype=torch.float64)
    elif schedule == "warmup50":
        betas = _warmup_beta(start, end, n_timestep, 0.5)
    elif schedule == "const":
        betas = end * torch.ones(n_timestep, dtype=torch.float64)
    elif schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / torch.linspace(
            n_timestep, 1, n_timestep, dtype=torch.float64)
    else:
        raise NotImplementedError(schedule)
    return betas
|
||
|
||
|
||
class VarianceSchedule(Module):
    """Discrete diffusion schedule with precomputed coefficient buffers.

    Every coefficient is computed in float64 and stored as a float32 buffer
    via :meth:`register`. Buffers are indexed by a 0-based timestep ``t`` in
    ``[0, num_steps)``.

    Args:
        num_steps: number of diffusion steps.
        beta_1: first beta of the schedule.
        beta_T: last (for ``'cust'``, plateau) beta of the schedule.
        mode: schedule name forwarded to ``make_beta_schedule``; only
            ``'linear'`` and ``'cust'`` are accepted here.
    """

    def __init__(self, num_steps, beta_1, beta_T, mode='linear'):
        super().__init__()
        assert mode in ('linear', 'cust')
        self.num_steps = num_steps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.mode = mode

        beta_start = self.beta_1
        beta_end = self.beta_T
        assert (beta_start <= beta_end), 'require beta_start < beta_end '

        logger.info('use beta: {} - {}', beta_1, beta_T)
        tic = time.time()
        betas = make_beta_schedule(mode, beta_start, beta_end, num_steps)

        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, 0)
        # Cumulative product shifted by one step, with a leading 1 for t=0.
        alphas_cumprod_prev = torch.cat(
            (torch.tensor([1], dtype=torch.float64), alphas_cumprod[:-1]), 0)
        # Variance of q(x_{t-1} | x_t, x_0); exactly 0 at t=0 because
        # alphas_cumprod_prev[0] == 1.
        posterior_variance = betas * (1 - alphas_cumprod_prev) / (
            1 - alphas_cumprod)

        self.register("betas", betas)
        self.register("alphas_cumprod", alphas_cumprod)
        self.register("alphas_cumprod_prev", alphas_cumprod_prev)

        self.register("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register("sqrt_one_minus_alphas_cumprod",
                      torch.sqrt(1 - alphas_cumprod))
        self.register("log_one_minus_alphas_cumprod",
                      torch.log(1 - alphas_cumprod))
        self.register("sqrt_recip_alphas_cumprod", torch.rsqrt(alphas_cumprod))
        self.register("sqrt_recipm1_alphas_cumprod",
                      torch.sqrt(1 / alphas_cumprod - 1))
        self.register("posterior_variance", posterior_variance)
        if len(posterior_variance) > 1:
            # posterior_variance[0] is 0 (log would be -inf), so the first
            # entry is replaced by posterior_variance[1] before the log.
            self.register(
                "posterior_log_variance_clipped",
                torch.log(torch.cat(
                    (posterior_variance[1:2], posterior_variance[1:]), 0)))
        else:
            # NOTE(review): with a single step this is log(0) == -inf.
            self.register("posterior_log_variance_clipped",
                          torch.log(posterior_variance[0].view(-1)))
        # Coefficients of x_0 and x_t in the mean of q(x_{t-1} | x_t, x_0).
        self.register("posterior_mean_coef1",
                      (betas * torch.sqrt(alphas_cumprod_prev) /
                       (1 - alphas_cumprod)))
        self.register("posterior_mean_coef2",
                      ((1 - alphas_cumprod_prev) * torch.sqrt(alphas) /
                       (1 - alphas_cumprod)))
        logger.info('built beta schedule: t={:.2f}s', time.time() - tic)

    def register(self, name, tensor):
        """Register ``tensor`` as a buffer named ``name``, cast to float32."""
        self.register_buffer(name, tensor.type(torch.float32))

    def all_sample_t(self):
        """Return the timesteps to visualize: every 50th step when there are
        more than 20 steps, otherwise every step."""
        if self.num_steps > 20:
            step = 50
        else:
            step = 1
        ts = np.arange(0, self.num_steps, step)
        return ts.tolist()

    def get_sigmas(self, t, flexibility):
        """Reverse-process std interpolated between two choices.

        ``flexibility == 1`` gives ``sqrt(betas[t])``; ``flexibility == 0``
        gives the posterior std ``sqrt(posterior_variance[t])``; values in
        between interpolate linearly.

        Args:
            t: 0-based timestep index (int or integer tensor).
            flexibility: interpolation weight in ``[0, 1]``.
        """
        assert 0 <= flexibility and flexibility <= 1
        # The sigmas_flex / sigmas_inflex buffers this method used to read
        # were never registered; derive them from buffers that are, using the
        # same 0-based indexing as the other coefficients.
        sigmas_flex = torch.sqrt(self.betas[t])
        sigmas_inflex = torch.sqrt(self.posterior_variance[t])
        sigmas = sigmas_flex * flexibility + sigmas_inflex * (1 - flexibility)
        return sigmas
|