From 42c2bcb5bb3fb78f10f8031187ee978408309594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Mon, 31 Jul 2023 22:41:14 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20(Schedulers)=20create=20Abstract.jl?= =?UTF-8?q?=20holding=20scheduler=20type=20and=20methods?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Schedulers/Abstract.jl | 44 ++++++++++++++++++++++++++++++++++++++ src/Schedulers/DDPM.jl | 31 +-------------------------- 2 files changed, 45 insertions(+), 30 deletions(-) create mode 100644 src/Schedulers/Abstract.jl diff --git a/src/Schedulers/Abstract.jl b/src/Schedulers/Abstract.jl new file mode 100644 index 0000000..5625d71 --- /dev/null +++ b/src/Schedulers/Abstract.jl @@ -0,0 +1,44 @@ +""" +Abstract type for schedulers. + +""" +abstract type Scheduler end + +""" +Add noise to clean data using the forward diffusion process. + +## Input + * `scheduler::Scheduler`: scheduler to use + * `x₀::AbstractArray`: clean data to add noise to + * `ϵ::AbstractArray`: noise to add to clean data + * `t::AbstractArray`: timesteps used to weight the noise + +## Output + * `xₜ::AbstractArray`: noisy data at the given timesteps +""" +function add_noise( + scheduler::Scheduler, + x₀::AbstractArray, + ϵ::AbstractArray, + t::AbstractArray, +) end + +""" +Remove noise from model output using the backward diffusion process. + +## Input + * `scheduler::Scheduler`: scheduler to use + * `xₜ::AbstractArray`: sample to be denoised + * `ϵᵧ::AbstractArray`: predicted noise to remove + * `t::AbstractArray`: timestep t of `xₜ` + +## Output + * `xₜ₋₁::AbstractArray`: denoised sample at t=t-1 + * `x̂₀::AbstractArray`: denoised sample at t=0 +""" +function step( + scheduler::Scheduler, + xₜ::AbstractArray, + ϵᵧ::AbstractArray, + t::AbstractArray, +) end diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 9e14dad..846e6f1 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -1,8 +1,4 @@ -""" -Abstract type for schedulers. - -""" -abstract type Scheduler end +include("Abstract.jl") """ Denoising Diffusion Probabilistic Models (DDPM) scheduler. @@ -69,18 +65,6 @@ function DDPM(V::DataType, β::AbstractVector) ) end -""" -Add noise to clean data using the forward diffusion process. - -## Input - * `scheduler::DDPM`: scheduler to use - * `x₀::AbstractArray`: clean data to add noise to - * `ϵ::AbstractArray`: noise to add to clean data - * `t::AbstractArray`: timesteps used to weight the noise - -## Output - * `xₜ::AbstractArray`: noisy data at the given timesteps -""" function add_noise( scheduler::DDPM, x₀::AbstractArray, @@ -102,19 +86,6 @@ function add_noise( return xₜ end -""" -Remove noise from model output using the backward diffusion process. - -## Input - * `scheduler::DDPM`: scheduler to use - * `xₜ::AbstractArray`: sample to be denoised - * `ϵᵧ::AbstractArray`: predicted noise to remove - * `t::AbstractArray`: timestep t of `xₜ` - -## Output - * `xₜ₋₁::AbstractArray`: denoised sample at t=t-1 - * `x̂₀::AbstractArray`: denoised sample at t=0 -""" function step( scheduler::DDPM, xₜ::AbstractArray,