mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-09 15:02:02 +00:00
🚚 (Schedulers) move schedulers in their own module
This commit is contained in:
parent
3570bc8379
commit
3c9841ffd9
|
@ -1,12 +1,6 @@
|
||||||
module Diffusers
|
module Diffusers
|
||||||
|
|
||||||
include("BetaSchedules/BetaSchedules.jl")
|
include("BetaSchedules/BetaSchedules.jl")
|
||||||
|
include("Schedulers/Schedulers.jl")
|
||||||
# abtract types
|
|
||||||
include("Schedulers.jl")
|
|
||||||
|
|
||||||
# concrete types
|
|
||||||
include("DDPM.jl")
|
|
||||||
# include("DDIM.jl")
|
|
||||||
|
|
||||||
end # module Diffusers
|
end # module Diffusers
|
||||||
|
|
|
@ -1,29 +0,0 @@
|
||||||
abstract type Scheduler end
|
|
||||||
|
|
||||||
"""
|
|
||||||
Add noise to clean data using the forward diffusion process.
|
|
||||||
|
|
||||||
## Input
|
|
||||||
* `scheduler::Scheduler`: scheduler to use
|
|
||||||
* `x₀::AbstractArray`: clean data to add noise to
|
|
||||||
* `ϵ::AbstractArray`: noise to add to clean data
|
|
||||||
* `t::AbstractArray`: timesteps used to weight the noise
|
|
||||||
|
|
||||||
## Output
|
|
||||||
* `xₜ::AbstractArray`: noisy data at the given timesteps
|
|
||||||
"""
|
|
||||||
function add_noise(
|
|
||||||
scheduler::Scheduler,
|
|
||||||
x₀::AbstractArray,
|
|
||||||
ϵ::AbstractArray,
|
|
||||||
t::AbstractArray,
|
|
||||||
)
|
|
||||||
⎷α̅ₜ = scheduler.⎷α̅[t]
|
|
||||||
⎷β̅ₜ = scheduler.⎷β̅[t]
|
|
||||||
|
|
||||||
# noisify clean data
|
|
||||||
# arxiv:2006.11239 Eq. 4
|
|
||||||
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
|
|
||||||
|
|
||||||
return xₜ
|
|
||||||
end
|
|
|
@ -1,4 +1,4 @@
|
||||||
include("Schedulers.jl")
|
abstract type Scheduler end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
||||||
|
@ -65,6 +65,34 @@ function DDPM(V::DataType, β::AbstractVector)
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
Add noise to clean data using the forward diffusion process.
|
||||||
|
|
||||||
|
## Input
|
||||||
|
* `scheduler::Scheduler`: scheduler to use
|
||||||
|
* `x₀::AbstractArray`: clean data to add noise to
|
||||||
|
* `ϵ::AbstractArray`: noise to add to clean data
|
||||||
|
* `t::AbstractArray`: timesteps used to weight the noise
|
||||||
|
|
||||||
|
## Output
|
||||||
|
* `xₜ::AbstractArray`: noisy data at the given timesteps
|
||||||
|
"""
|
||||||
|
function add_noise(
|
||||||
|
scheduler::DDPM,
|
||||||
|
x₀::AbstractArray,
|
||||||
|
ϵ::AbstractArray,
|
||||||
|
t::AbstractArray,
|
||||||
|
)
|
||||||
|
⎷α̅ₜ = scheduler.⎷α̅[t]
|
||||||
|
⎷β̅ₜ = scheduler.⎷β̅[t]
|
||||||
|
|
||||||
|
# noisify clean data
|
||||||
|
# arxiv:2006.11239 Eq. 4
|
||||||
|
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
|
||||||
|
|
||||||
|
return xₜ
|
||||||
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Remove noise from model output using the backward diffusion process.
|
Remove noise from model output using the backward diffusion process.
|
||||||
|
|
9
src/Schedulers/Schedulers.jl
Normal file
9
src/Schedulers/Schedulers.jl
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
module Schedulers
|
||||||
|
|
||||||
|
include("DDPM.jl")
|
||||||
|
|
||||||
|
export DDPM
|
||||||
|
|
||||||
|
export add_noise, step
|
||||||
|
|
||||||
|
end # module Schedulers
|
Loading…
Reference in a new issue