mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-09 23:12:03 +00:00
📝 docstrings yet again + rename some variables
This commit is contained in:
parent
01a172db1e
commit
425469e873
27
src/DDPM.jl
27
src/DDPM.jl
|
@ -3,34 +3,35 @@ include("Schedulers.jl")
|
||||||
"""
|
"""
|
||||||
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
||||||
|
|
||||||
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006).
|
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
||||||
"""
|
"""
|
||||||
struct DDPM{V<:AbstractVector} <: Scheduler
|
struct DDPM{V<:AbstractVector} <: Scheduler
|
||||||
# number of diffusion steps used to train the model.
|
# number of diffusion steps used to train the model.
|
||||||
num_train_timesteps::Int
|
T_train::Int
|
||||||
|
|
||||||
# the betas to use for the diffusion steps
|
# the betas used for the diffusion steps
|
||||||
βs::V
|
β::V
|
||||||
αs::V
|
|
||||||
|
# internal variables used for computation (derived from β)
|
||||||
|
α::V
|
||||||
α_cumprods::V
|
α_cumprods::V
|
||||||
α_cumprod_prevs::V
|
α_cumprod_prevs::V
|
||||||
|
|
||||||
sqrt_α_cumprods::V
|
sqrt_α_cumprods::V
|
||||||
sqrt_one_minus_α_cumprods::V
|
sqrt_one_minus_α_cumprods::V
|
||||||
end
|
end
|
||||||
|
|
||||||
function DDPM(V::DataType, βs::AbstractVector)
|
function DDPM(V::DataType, β::AbstractVector)
|
||||||
αs = 1 .- βs
|
α = 1 .- β
|
||||||
α_cumprods = cumprod(αs)
|
α_cumprods = cumprod(α)
|
||||||
α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]
|
α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]
|
||||||
|
|
||||||
sqrt_α_cumprods = sqrt.(α_cumprods)
|
sqrt_α_cumprods = sqrt.(α_cumprods)
|
||||||
sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
|
sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
|
||||||
|
|
||||||
DDPM{V}(
|
DDPM{V}(
|
||||||
length(βs),
|
length(β),
|
||||||
βs,
|
β,
|
||||||
αs,
|
α,
|
||||||
α_cumprods,
|
α_cumprods,
|
||||||
α_cumprod_prevs,
|
α_cumprod_prevs,
|
||||||
sqrt_α_cumprods,
|
sqrt_α_cumprods,
|
||||||
|
@ -81,7 +82,7 @@ function step(
|
||||||
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
|
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
|
||||||
|
|
||||||
# 6. Add noise
|
# 6. Add noise
|
||||||
variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output))
|
variance = sqrt.(scheduler.β[timesteps])' .* randn(size(model_output))
|
||||||
pred_prev_sample = pred_prev_sample + variance
|
pred_prev_sample = pred_prev_sample + variance
|
||||||
|
|
||||||
return pred_prev_sample, x_0_pred
|
return pred_prev_sample, x_0_pred
|
||||||
|
|
|
@ -3,6 +3,8 @@ abstract type Scheduler end
|
||||||
"""
|
"""
|
||||||
Add noise to clean data using the forward diffusion process.
|
Add noise to clean data using the forward diffusion process.
|
||||||
|
|
||||||
|
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4)
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* scheduler (`Scheduler`): scheduler to use
|
* scheduler (`Scheduler`): scheduler to use
|
||||||
* clean_data (`AbstractArray`): clean data to add noise to
|
* clean_data (`AbstractArray`): clean data to add noise to
|
||||||
|
|
Loading…
Reference in a new issue