diff --git a/src/Schedulers/Abstract.jl b/src/Schedulers/Abstract.jl new file mode 100644 index 0000000..5625d71 --- /dev/null +++ b/src/Schedulers/Abstract.jl @@ -0,0 +1,44 @@ +""" +Abstract type for schedulers. + +""" +abstract type Scheduler end + +""" +Add noise to clean data using the forward diffusion process. + +## Input + * `scheduler::Scheduler`: scheduler to use + * `x₀::AbstractArray`: clean data to add noise to + * `ϵ::AbstractArray`: noise to add to clean data + * `t::AbstractArray`: timesteps used to weight the noise + +## Output + * `xₜ::AbstractArray`: noisy data at the given timesteps +""" +function add_noise( + scheduler::Scheduler, + x₀::AbstractArray, + ϵ::AbstractArray, + t::AbstractArray, +) end + +""" +Remove noise from model output using the backward diffusion process. + +## Input + * `scheduler::Scheduler`: scheduler to use + * `xₜ::AbstractArray`: sample to be denoised + * `ϵᵧ::AbstractArray`: predicted noise to remove + * `t::AbstractArray`: timestep t of `xₜ` + +## Output + * `xₜ₋₁::AbstractArray`: denoised sample at t=t-1 + * `x̂₀::AbstractArray`: denoised sample at t=0 +""" +function step( + scheduler::Scheduler, + xₜ::AbstractArray, + ϵᵧ::AbstractArray, + t::AbstractArray, +) end diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 9e14dad..846e6f1 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -1,8 +1,4 @@ -""" -Abstract type for schedulers. - -""" -abstract type Scheduler end +include("Abstract.jl") """ Denoising Diffusion Probabilistic Models (DDPM) scheduler. @@ -69,18 +65,6 @@ function DDPM(V::DataType, β::AbstractVector) ) end -""" -Add noise to clean data using the forward diffusion process. - -## Input - * `scheduler::DDPM`: scheduler to use - * `x₀::AbstractArray`: clean data to add noise to - * `ϵ::AbstractArray`: noise to add to clean data - * `t::AbstractArray`: timesteps used to weight the noise - -## Output - * `xₜ::AbstractArray`: noisy data at the given timesteps -""" function add_noise( scheduler::DDPM, x₀::AbstractArray, @@ -102,19 +86,6 @@ function add_noise( return xₜ end -""" -Remove noise from model output using the backward diffusion process. - -## Input - * `scheduler::DDPM`: scheduler to use - * `xₜ::AbstractArray`: sample to be denoised - * `ϵᵧ::AbstractArray`: predicted noise to remove - * `t::AbstractArray`: timestep t of `xₜ` - -## Output - * `xₜ₋₁::AbstractArray`: denoised sample at t=t-1 - * `x̂₀::AbstractArray`: denoised sample at t=0 -""" function step( scheduler::DDPM, xₜ::AbstractArray,