📝 docstrings yet again + rename some variables

This commit is contained in:
Laureηt 2023-07-28 19:23:03 +02:00
parent 01a172db1e
commit 425469e873
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
2 changed files with 16 additions and 13 deletions

View file

@ -3,34 +3,35 @@ include("Schedulers.jl")
""" """
Denoising Diffusion Probabilistic Models (DDPM) scheduler. Denoising Diffusion Probabilistic Models (DDPM) scheduler.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006). cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
""" """
struct DDPM{V<:AbstractVector} <: Scheduler struct DDPM{V<:AbstractVector} <: Scheduler
# number of diffusion steps used to train the model. # number of diffusion steps used to train the model.
num_train_timesteps::Int T_train::Int
# the betas to use for the diffusion steps # the betas used for the diffusion steps
βs::V β::V
αs::V
# internal variables used for computation (derived from β)
α::V
α_cumprods::V α_cumprods::V
α_cumprod_prevs::V α_cumprod_prevs::V
sqrt_α_cumprods::V sqrt_α_cumprods::V
sqrt_one_minus_α_cumprods::V sqrt_one_minus_α_cumprods::V
end end
function DDPM(V::DataType, βs::AbstractVector) function DDPM(V::DataType, β::AbstractVector)
αs = 1 .- βs α = 1 .- β
α_cumprods = cumprod(αs) α_cumprods = cumprod(α)
α_cumprod_prevs = [1, (α_cumprods[1:end-1])...] α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]
sqrt_α_cumprods = sqrt.(α_cumprods) sqrt_α_cumprods = sqrt.(α_cumprods)
sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods) sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
DDPM{V}( DDPM{V}(
length(βs), length(β),
βs, β,
αs, α,
α_cumprods, α_cumprods,
α_cumprod_prevs, α_cumprod_prevs,
sqrt_α_cumprods, sqrt_α_cumprods,
@ -81,7 +82,7 @@ function step(
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
# 6. Add noise # 6. Add noise
variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output)) variance = sqrt.(scheduler.β[timesteps])' .* randn(size(model_output))
pred_prev_sample = pred_prev_sample + variance pred_prev_sample = pred_prev_sample + variance
return pred_prev_sample, x_0_pred return pred_prev_sample, x_0_pred

View file

@ -3,6 +3,8 @@ abstract type Scheduler end
""" """
Add noise to clean data using the forward diffusion process. Add noise to clean data using the forward diffusion process.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4)
## Input ## Input
* scheduler (`Scheduler`): scheduler to use * scheduler (`Scheduler`): scheduler to use
* clean_data (`AbstractArray`): clean data to add noise to * clean_data (`AbstractArray`): clean data to add noise to