♻️ (Schedulers) rename methods names

add_noise -> forward
step -> reverse
This commit is contained in:
Laureηt 2023-08-14 16:26:15 +02:00
parent b6c8309733
commit cc0326450b
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
5 changed files with 13 additions and 13 deletions

View file

@ -16,7 +16,7 @@ Add noise to clean data using the forward diffusion process.
## Output
* `xₜ::AbstractArray`: noisy data at the given timesteps
"""
function add_noise(
function forward(
scheduler::Scheduler,
x₀::AbstractArray,
ϵ::AbstractArray,
@ -36,7 +36,7 @@ Remove noise from model output using the backward diffusion process.
* `xₜ₋₁::AbstractArray`: denoised sample at t=t-1
* `x̂₀::AbstractArray`: denoised sample at t=0
"""
function step(
function reverse(
scheduler::Scheduler,
xₜ::AbstractArray,
ϵᵧ::AbstractArray,

View file

@ -65,7 +65,7 @@ function DDPM(β::AbstractVector)
)
end
function add_noise(
function forward(
scheduler::DDPM,
x₀::AbstractArray,
ϵ::AbstractArray,
@ -86,7 +86,7 @@ function add_noise(
return xₜ
end
function step(
function reverse(
scheduler::DDPM,
xₜ::AbstractArray,
ϵᵧ::AbstractArray,

View file

@ -7,7 +7,7 @@ export
DDPM,
# Scheduler methods
add_noise,
step
forward,
reverse
end # module Schedulers

View file

@ -15,7 +15,7 @@ using Test
α̅_cosine = cumprod(1 .- β_cosine)
α̅_sigmoid = cumprod(1 .- β_sigmoid)
# arxiv:2208.11970 (eq. 109)
# arxiv:2208.11970 Eq. 109
SNR_linear = α̅_linear ./ (1 .- α̅_linear)
SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear)
SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine)

View file

@ -1,9 +1,9 @@
import Diffusers: step, add_noise, DDPM, cosine_beta_schedule, rescale_zero_terminal_snr
import Diffusers: reverse, forward, DDPM, cosine_beta_schedule
using Statistics
using Test
@testset "Schedulers tests" begin
@testset "check `step` correctness" begin
@testset "check `reverse` correctness" begin
T = 10
batch_size = 8
size = 128
@ -20,15 +20,15 @@ using Test
for t in 1:T
t = ones(UInt32, batch_size) .* t
# corrupt x₀ with noise
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t)
xₜ = forward(ddpm, x₀, ϵ, t)
# suppose a model predicted ϵ perfectly
_, x̂₀ = Diffusers.step(ddpm, xₜ, ϵ, t)
_, x̂₀ = reverse(ddpm, xₜ, ϵ, t)
# test that we recover x₀
@test x̂₀ x₀
end
end
@testset "check `add_noise` terminal SNR" begin
@testset "check `forward` terminal SNR" begin
T = 10
batch_size = 1
size = 2500
@ -46,7 +46,7 @@ using Test
t = ones(UInt32, batch_size) .* T
# corrupt x₀ with noise
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t)
xₜ = forward(ddpm, x₀, ϵ, t)
@test std(xₜ) 1.0 atol = 1.0f-3
@test mean(xₜ) 0.0 atol = 1.0f-2