📝 docstrings yet again + rename some variables

This commit is contained in:
Laureηt 2023-07-28 19:23:03 +02:00
parent 01a172db1e
commit 425469e873
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
2 changed files with 16 additions and 13 deletions

View file

@ -3,34 +3,35 @@ include("Schedulers.jl")
"""
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006).
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
"""
struct DDPM{V<:AbstractVector} <: Scheduler
# number of diffusion steps used to train the model.
num_train_timesteps::Int
T_train::Int
# the betas to use for the diffusion steps
βs::V
αs::V
# the betas used for the diffusion steps
β::V
# internal variables used for computation (derived from β)
α::V
α_cumprods::V
α_cumprod_prevs::V
sqrt_α_cumprods::V
sqrt_one_minus_α_cumprods::V
end
function DDPM(V::DataType, βs::AbstractVector)
αs = 1 .- βs
α_cumprods = cumprod(αs)
function DDPM(V::DataType, β::AbstractVector)
α = 1 .- β
α_cumprods = cumprod(α)
α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]
sqrt_α_cumprods = sqrt.(α_cumprods)
sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
DDPM{V}(
length(βs),
βs,
αs,
length(β),
β,
α,
α_cumprods,
α_cumprod_prevs,
sqrt_α_cumprods,
@ -81,7 +82,7 @@ function step(
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
# 6. Add noise
variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output))
variance = sqrt.(scheduler.β[timesteps])' .* randn(size(model_output))
pred_prev_sample = pred_prev_sample + variance
return pred_prev_sample, x_0_pred

View file

@ -3,6 +3,8 @@ abstract type Scheduler end
"""
Add noise to clean data using the forward diffusion process.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4)
## Input
* scheduler (`Scheduler`): scheduler to use
* clean_data (`AbstractArray`): clean data to add noise to