LION/utils/diffusion.py
2023-01-23 00:14:49 -05:00

171 lines
7.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""
copied and modified from source:
https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/diffusion_discretized.py
"""
from loguru import logger
import time
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, ModuleList
import numpy as np
def extract(input, t, shape):
    """Gather per-timestep scalar coefficients and reshape for broadcasting.

    Args:
        input: 1-D tensor of per-timestep values (e.g. a beta schedule).
        t: 1-D integer tensor of timestep indices, one per batch element.
        shape: shape of the tensor the result will broadcast against; only
            ``shape[0]`` (the batch size) and ``len(shape)`` are used.

    Returns:
        Tensor of shape ``(shape[0], 1, ..., 1)`` (same rank as ``shape``)
        holding ``input[t]``, ready to broadcast over the remaining dims.
    """
    # NOTE: `input` shadows the builtin; name kept for caller compatibility.
    out = torch.gather(input, 0, t.to(input.device))
    # (B,) -> (B, 1, ..., 1) so the coefficients broadcast per batch element.
    # (The original also computed B = t.shape[0] but never used it.)
    return out.reshape(shape[0], *([1] * (len(shape) - 1)))
def make_beta_schedule(schedule, start, end, n_timestep):
    """Build the diffusion noise schedule (betas) as a float64 tensor.

    Args:
        schedule: one of 'cust', 'quad', 'linear', 'warmup10', 'warmup50',
            'const', 'jsd'.
        start: beta at the first timestep (ignored by 'const' and 'jsd').
        end: beta at the last timestep.
        n_timestep: number of diffusion timesteps T.

    Returns:
        1-D torch.float64 tensor of length ``n_timestep``.

    Raises:
        NotImplementedError: for an unrecognized schedule name.
    """
    if schedule == "cust":  # warmup schedule used for the airplane category
        betas = end * np.ones(n_timestep, dtype=np.float64)
        # Linear ramp start -> end over the first 10% of steps, constant after.
        warmup_time = int(n_timestep * 0.1)
        betas[:warmup_time] = np.linspace(
            start, end, warmup_time, dtype=np.float64)
        betas = torch.from_numpy(betas)
    elif schedule == "quad":
        # Linear in sqrt(beta) space, then squared.
        betas = torch.linspace(start ** 0.5,
                               end ** 0.5,
                               n_timestep,
                               dtype=torch.float64) ** 2
    elif schedule == 'linear':
        betas = torch.linspace(start, end, n_timestep, dtype=torch.float64)
    elif schedule == 'warmup10':
        # NOTE(review): _warmup_beta is not defined in the visible portion of
        # this file — confirm it exists before relying on these two branches.
        betas = _warmup_beta(start, end, n_timestep, 0.1)
    elif schedule == 'warmup50':
        betas = _warmup_beta(start, end, n_timestep, 0.5)
    elif schedule == 'const':
        betas = end * torch.ones(n_timestep, dtype=torch.float64)
    elif schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / (torch.linspace(
            n_timestep, 1, n_timestep, dtype=torch.float64))
    else:
        raise NotImplementedError(schedule)
    return betas
class VarianceSchedule(Module):
def __init__(self, num_steps, beta_1, beta_T, mode='linear'):
    """Precompute the DDPM variance schedule and all derived coefficients.

    Args:
        num_steps: number of diffusion timesteps T.
        beta_1: beta at the first timestep (smallest noise level).
        beta_T: beta at the last timestep; must satisfy beta_1 <= beta_T.
        mode: beta schedule name; only 'linear' and 'cust' are accepted.

    All derived tensors are computed in float64 and stored as float32
    buffers via ``self.register`` so they move with the module's device.
    """
    super().__init__()
    assert mode in ('linear', 'cust')
    self.num_steps = num_steps
    self.beta_1 = beta_1
    self.beta_T = beta_T
    self.mode = mode
    assert beta_1 <= beta_T, 'require beta_start <= beta_end'
    logger.info('use beta: {} - {}', beta_1, beta_T)
    tic = time.time()
    betas = make_beta_schedule(mode, beta_1, beta_T, num_steps)

    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, 0)
    # \bar{alpha}_{t-1}: shifted cumulative product, defined as 1 at t=0.
    alphas_cumprod_prev = torch.cat(
        (torch.tensor([1], dtype=torch.float64), alphas_cumprod[:-1]), 0)
    # Variance of the posterior q(x_{t-1} | x_t, x_0).
    posterior_variance = betas * (1 - alphas_cumprod_prev) / (
        1 - alphas_cumprod)
    self.register("betas", betas)
    self.register("alphas_cumprod", alphas_cumprod)
    self.register("alphas_cumprod_prev", alphas_cumprod_prev)
    self.register("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
    self.register("sqrt_one_minus_alphas_cumprod",
                  torch.sqrt(1 - alphas_cumprod))
    self.register("log_one_minus_alphas_cumprod",
                  torch.log(1 - alphas_cumprod))
    self.register("sqrt_recip_alphas_cumprod", torch.rsqrt(alphas_cumprod))
    self.register("sqrt_recipm1_alphas_cumprod",
                  torch.sqrt(1 / alphas_cumprod - 1))
    self.register("posterior_variance", posterior_variance)
    # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1), so its log
    # would be -inf; clip by substituting the value at index 1.
    if len(posterior_variance) > 1:
        self.register("posterior_log_variance_clipped",
                      torch.log(
                          torch.cat((posterior_variance[1].view(
                              1, 1), posterior_variance[1:].view(-1, 1)), 0)).view(-1)
                      )
    else:
        self.register("posterior_log_variance_clipped",
                      torch.log(posterior_variance[0].view(-1)))
    # Coefficients of x_0 and x_t in the posterior mean.
    self.register("posterior_mean_coef1",
                  (betas * torch.sqrt(alphas_cumprod_prev) /
                   (1 - alphas_cumprod)))
    self.register("posterior_mean_coef2",
                  ((1 - alphas_cumprod_prev) * torch.sqrt(alphas) /
                   (1 - alphas_cumprod)))
    logger.info('built beta schedule: t={:.2f}s', time.time() - tic)
def register(self, name, tensor):
self.register_buffer(name, tensor.type(torch.float32))
def all_sample_t(self):
if self.num_steps > 20:
step = 50
else:
step = 1
ts = np.arange(0, self.num_steps, step)
return ts.tolist()
def get_sigmas(self, t, flexibility):
assert 0 <= flexibility and flexibility <= 1
sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
1 - flexibility)
return sigmas