(Schedulers) create Abstract.jl holding scheduler type and methods

This commit is contained in:
Laureηt 2023-07-31 22:41:14 +02:00
parent 9d5201d068
commit 42c2bcb5bb
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
2 changed files with 45 additions and 30 deletions

View file

@ -0,0 +1,44 @@
"""
Abstract type for schedulers.
"""
abstract type Scheduler end
"""
Add noise to clean data using the forward diffusion process.
## Input
* `scheduler::Scheduler`: scheduler to use
* `x₀::AbstractArray`: clean data to add noise to
* `ϵ::AbstractArray`: noise to add to clean data
* `t::AbstractArray`: timesteps used to weight the noise
## Output
* `xₜ::AbstractArray`: noisy data at the given timesteps
"""
function add_noise(
scheduler::Scheduler,
x₀::AbstractArray,
ϵ::AbstractArray,
t::AbstractArray,
) end
"""
Remove noise from model output using the backward diffusion process.
## Input
* `scheduler::Scheduler`: scheduler to use
* `xₜ::AbstractArray`: sample to be denoised
* `ϵᵧ::AbstractArray`: predicted noise to remove
* `t::AbstractArray`: timestep t of `xₜ`
## Output
* `xₜ₋₁::AbstractArray`: denoised sample at t=t-1
* `x̂₀::AbstractArray`: denoised sample at t=0
"""
function step(
scheduler::Scheduler,
xₜ::AbstractArray,
ϵᵧ::AbstractArray,
t::AbstractArray,
) end

View file

@ -1,8 +1,4 @@
"""
Abstract type for schedulers.
"""
abstract type Scheduler end
include("Abstract.jl")
"""
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
@ -69,18 +65,6 @@ function DDPM(V::DataType, β::AbstractVector)
)
end
"""
Add noise to clean data using the forward diffusion process.
## Input
* `scheduler::DDPM`: scheduler to use
* `x₀::AbstractArray`: clean data to add noise to
* `ϵ::AbstractArray`: noise to add to clean data
* `t::AbstractArray`: timesteps used to weight the noise
## Output
* `xₜ::AbstractArray`: noisy data at the given timesteps
"""
function add_noise(
scheduler::DDPM,
x₀::AbstractArray,
@ -102,19 +86,6 @@ function add_noise(
return xₜ
end
"""
Remove noise from model output using the backward diffusion process.
## Input
* `scheduler::DDPM`: scheduler to use
* `xₜ::AbstractArray`: sample to be denoised
* `ϵᵧ::AbstractArray`: predicted noise to remove
* `t::AbstractArray`: timestep t of `xₜ`
## Output
* `xₜ₋₁::AbstractArray`: denoised sample at t=t-1
* `x̂₀::AbstractArray`: denoised sample at t=0
"""
function step(
scheduler::DDPM,
xₜ::AbstractArray,