diff --git a/src/DDPM.jl b/src/DDPM.jl index 826b6a8..be491cf 100644 --- a/src/DDPM.jl +++ b/src/DDPM.jl @@ -3,34 +3,35 @@ include("Schedulers.jl") """ Denoising Diffusion Probabilistic Models (DDPM) scheduler. -cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006). +cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) """ struct DDPM{V<:AbstractVector} <: Scheduler # number of diffusion steps used to train the model. - num_train_timesteps::Int + T_train::Int - # the betas to use for the diffusion steps - βs::V - αs::V + # the betas used for the diffusion steps + β::V + + # internal variables used for computation (derived from β) + α::V α_cumprods::V α_cumprod_prevs::V - sqrt_α_cumprods::V sqrt_one_minus_α_cumprods::V end -function DDPM(V::DataType, βs::AbstractVector) - αs = 1 .- βs - α_cumprods = cumprod(αs) +function DDPM(V::DataType, β::AbstractVector) + α = 1 .- β + α_cumprods = cumprod(α) α_cumprod_prevs = [1, (α_cumprods[1:end-1])...] sqrt_α_cumprods = sqrt.(α_cumprods) sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods) DDPM{V}( - length(βs), - βs, - αs, + length(β), + β, + α, α_cumprods, α_cumprod_prevs, sqrt_α_cumprods, @@ -81,7 +82,7 @@ function step( pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample # 6. Add noise - variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output)) + variance = sqrt.(scheduler.β[timesteps])' .* randn(size(model_output)) pred_prev_sample = pred_prev_sample + variance return pred_prev_sample, x_0_pred diff --git a/src/Schedulers.jl b/src/Schedulers.jl index 6c7a5f2..707d21f 100644 --- a/src/Schedulers.jl +++ b/src/Schedulers.jl @@ -3,6 +3,8 @@ abstract type Scheduler end """ Add noise to clean data using the forward diffusion process. +cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4) + ## Input * scheduler (`Scheduler`): scheduler to use * clean_data (`AbstractArray`): clean data to add noise to