diff --git a/src/Diffusers.jl b/src/Diffusers.jl index f4d683a..eeb80a2 100644 --- a/src/Diffusers.jl +++ b/src/Diffusers.jl @@ -1,12 +1,6 @@ module Diffusers include("BetaSchedules/BetaSchedules.jl") - -# abtract types -include("Schedulers.jl") - -# concrete types -include("DDPM.jl") -# include("DDIM.jl") +include("Schedulers/Schedulers.jl") end # module Diffusers diff --git a/src/Schedulers.jl b/src/Schedulers.jl deleted file mode 100644 index 1a1abe4..0000000 --- a/src/Schedulers.jl +++ /dev/null @@ -1,29 +0,0 @@ -abstract type Scheduler end - -""" -Add noise to clean data using the forward diffusion process. - -## Input - * `scheduler::Scheduler`: scheduler to use - * `x₀::AbstractArray`: clean data to add noise to - * `ϵ::AbstractArray`: noise to add to clean data - * `t::AbstractArray`: timesteps used to weight the noise - -## Output - * `xₜ::AbstractArray`: noisy data at the given timesteps -""" -function add_noise( - scheduler::Scheduler, - x₀::AbstractArray, - ϵ::AbstractArray, - t::AbstractArray, -) - ⎷α̅ₜ = scheduler.⎷α̅[t] - ⎷β̅ₜ = scheduler.⎷β̅[t] - - # noisify clean data - # arxiv:2006.11239 Eq. 4 - xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ - - return xₜ -end diff --git a/src/DDPM.jl b/src/Schedulers/DDPM.jl similarity index 80% rename from src/DDPM.jl rename to src/Schedulers/DDPM.jl index d8aec81..1f655e7 100644 --- a/src/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -1,4 +1,4 @@ -include("Schedulers.jl") +abstract type Scheduler end """ Denoising Diffusion Probabilistic Models (DDPM) scheduler. @@ -65,6 +65,34 @@ function DDPM(V::DataType, β::AbstractVector) ) end +""" +Add noise to clean data using the forward diffusion process. + +## Input + * `scheduler::Scheduler`: scheduler to use + * `x₀::AbstractArray`: clean data to add noise to + * `ϵ::AbstractArray`: noise to add to clean data + * `t::AbstractArray`: timesteps used to weight the noise + +## Output + * `xₜ::AbstractArray`: noisy data at the given timesteps +""" +function add_noise( + scheduler::DDPM, + x₀::AbstractArray, + ϵ::AbstractArray, + t::AbstractArray, +) + ⎷α̅ₜ = scheduler.⎷α̅[t] + ⎷β̅ₜ = scheduler.⎷β̅[t] + + # noisify clean data + # arxiv:2006.11239 Eq. 4 + xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ + + return xₜ +end + """ Remove noise from model output using the backward diffusion process. diff --git a/src/Schedulers/Schedulers.jl b/src/Schedulers/Schedulers.jl new file mode 100644 index 0000000..f765611 --- /dev/null +++ b/src/Schedulers/Schedulers.jl @@ -0,0 +1,9 @@ +module Schedulers + +include("DDPM.jl") + +export DDPM + +export add_noise, step + +end # module Schedulers