♻️ (Schedulers) rename methods names

add_noise -> forward
step -> reverse
This commit is contained in:
Laureηt 2023-08-14 16:26:15 +02:00
parent b6c8309733
commit cc0326450b
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
5 changed files with 13 additions and 13 deletions

View file

@ -16,7 +16,7 @@ Add noise to clean data using the forward diffusion process.
## Output ## Output
* `xₜ::AbstractArray`: noisy data at the given timesteps * `xₜ::AbstractArray`: noisy data at the given timesteps
""" """
function add_noise( function forward(
scheduler::Scheduler, scheduler::Scheduler,
x₀::AbstractArray, x₀::AbstractArray,
ϵ::AbstractArray, ϵ::AbstractArray,
@ -36,7 +36,7 @@ Remove noise from model output using the backward diffusion process.
* `xₜ₋₁::AbstractArray`: denoised sample at t=t-1 * `xₜ₋₁::AbstractArray`: denoised sample at t=t-1
* `x̂₀::AbstractArray`: denoised sample at t=0 * `x̂₀::AbstractArray`: denoised sample at t=0
""" """
function step( function reverse(
scheduler::Scheduler, scheduler::Scheduler,
xₜ::AbstractArray, xₜ::AbstractArray,
ϵᵧ::AbstractArray, ϵᵧ::AbstractArray,

View file

@ -65,7 +65,7 @@ function DDPM(β::AbstractVector)
) )
end end
function add_noise( function forward(
scheduler::DDPM, scheduler::DDPM,
x₀::AbstractArray, x₀::AbstractArray,
ϵ::AbstractArray, ϵ::AbstractArray,
@ -86,7 +86,7 @@ function add_noise(
return xₜ return xₜ
end end
function step( function reverse(
scheduler::DDPM, scheduler::DDPM,
xₜ::AbstractArray, xₜ::AbstractArray,
ϵᵧ::AbstractArray, ϵᵧ::AbstractArray,

View file

@ -7,7 +7,7 @@ export
DDPM, DDPM,
# Scheduler methods # Scheduler methods
add_noise, forward,
step reverse
end # module Schedulers end # module Schedulers

View file

@ -15,7 +15,7 @@ using Test
α̅_cosine = cumprod(1 .- β_cosine) α̅_cosine = cumprod(1 .- β_cosine)
α̅_sigmoid = cumprod(1 .- β_sigmoid) α̅_sigmoid = cumprod(1 .- β_sigmoid)
# arxiv:2208.11970 (eq. 109) # arxiv:2208.11970 Eq. 109
SNR_linear = α̅_linear ./ (1 .- α̅_linear) SNR_linear = α̅_linear ./ (1 .- α̅_linear)
SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear) SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear)
SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine) SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine)

View file

@ -1,9 +1,9 @@
import Diffusers: step, add_noise, DDPM, cosine_beta_schedule, rescale_zero_terminal_snr import Diffusers: reverse, forward, DDPM, cosine_beta_schedule
using Statistics using Statistics
using Test using Test
@testset "Schedulers tests" begin @testset "Schedulers tests" begin
@testset "check `step` correctness" begin @testset "check `reverse` correctness" begin
T = 10 T = 10
batch_size = 8 batch_size = 8
size = 128 size = 128
@ -20,15 +20,15 @@ using Test
for t in 1:T for t in 1:T
t = ones(UInt32, batch_size) .* t t = ones(UInt32, batch_size) .* t
# corrupt x₀ with noise # corrupt x₀ with noise
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t) xₜ = forward(ddpm, x₀, ϵ, t)
# suppose a model predicted ϵ perfectly # suppose a model predicted ϵ perfectly
_, x̂₀ = Diffusers.step(ddpm, xₜ, ϵ, t) _, x̂₀ = reverse(ddpm, xₜ, ϵ, t)
# test that we recover x₀ # test that we recover x₀
@test x̂₀ x₀ @test x̂₀ x₀
end end
end end
@testset "check `add_noise` terminal SNR" begin @testset "check `forward` terminal SNR" begin
T = 10 T = 10
batch_size = 1 batch_size = 1
size = 2500 size = 2500
@ -46,7 +46,7 @@ using Test
t = ones(UInt32, batch_size) .* T t = ones(UInt32, batch_size) .* T
# corrupt x₀ with noise # corrupt x₀ with noise
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t) xₜ = forward(ddpm, x₀, ϵ, t)
@test std(xₜ) 1.0 atol = 1.0f-3 @test std(xₜ) 1.0 atol = 1.0f-3
@test mean(xₜ) 0.0 atol = 1.0f-2 @test mean(xₜ) 0.0 atol = 1.0f-2