🚚 (Schedulers) move schedulers in their own module

This commit is contained in:
Laureηt 2023-07-29 16:14:55 +02:00
parent 3570bc8379
commit 3c9841ffd9
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
4 changed files with 39 additions and 37 deletions

View file

@ -1,12 +1,6 @@
module Diffusers
include("BetaSchedules/BetaSchedules.jl")
# abtract types
include("Schedulers.jl")
# concrete types
include("DDPM.jl")
# include("DDIM.jl")
include("Schedulers/Schedulers.jl")
end # module Diffusers

View file

@ -1,29 +0,0 @@
abstract type Scheduler end
"""
Add noise to clean data using the forward diffusion process.
## Input
* `scheduler::Scheduler`: scheduler to use
* `x₀::AbstractArray`: clean data to add noise to
* `ϵ::AbstractArray`: noise to add to clean data
* `t::AbstractArray`: timesteps used to weight the noise
## Output
* `xₜ::AbstractArray`: noisy data at the given timesteps
"""
function add_noise(
scheduler::Scheduler,
x₀::AbstractArray,
ϵ::AbstractArray,
t::AbstractArray,
)
⎷α̅ₜ = scheduler.⎷α̅[t]
⎷β̅ₜ = scheduler.⎷β̅[t]
# noisify clean data
# arxiv:2006.11239 Eq. 4
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
return xₜ
end

View file

@ -1,4 +1,4 @@
include("Schedulers.jl")
abstract type Scheduler end
"""
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
@ -65,6 +65,34 @@ function DDPM(V::DataType, β::AbstractVector)
)
end
"""
Add noise to clean data using the forward diffusion process.
## Input
* `scheduler::Scheduler`: scheduler to use
* `x₀::AbstractArray`: clean data to add noise to
* `ϵ::AbstractArray`: noise to add to clean data
* `t::AbstractArray`: timesteps used to weight the noise
## Output
* `xₜ::AbstractArray`: noisy data at the given timesteps
"""
function add_noise(
scheduler::DDPM,
x₀::AbstractArray,
ϵ::AbstractArray,
t::AbstractArray,
)
⎷α̅ₜ = scheduler.⎷α̅[t]
⎷β̅ₜ = scheduler.⎷β̅[t]
# noisify clean data
# arxiv:2006.11239 Eq. 4
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
return xₜ
end
"""
Remove noise from model output using the backward diffusion process.

View file

@ -0,0 +1,9 @@
module Schedulers
include("DDPM.jl")
export DDPM
export add_noise, step
end # module Schedulers