mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-20 19:45:26 +00:00
♻️ (Schedulers) rename methods names
add_noise -> forward step -> reverse
This commit is contained in:
parent
b6c8309733
commit
cc0326450b
|
@ -16,7 +16,7 @@ Add noise to clean data using the forward diffusion process.
|
||||||
## Output
|
## Output
|
||||||
* `xₜ::AbstractArray`: noisy data at the given timesteps
|
* `xₜ::AbstractArray`: noisy data at the given timesteps
|
||||||
"""
|
"""
|
||||||
function add_noise(
|
function forward(
|
||||||
scheduler::Scheduler,
|
scheduler::Scheduler,
|
||||||
x₀::AbstractArray,
|
x₀::AbstractArray,
|
||||||
ϵ::AbstractArray,
|
ϵ::AbstractArray,
|
||||||
|
@ -36,7 +36,7 @@ Remove noise from model output using the backward diffusion process.
|
||||||
* `xₜ₋₁::AbstractArray`: denoised sample at t=t-1
|
* `xₜ₋₁::AbstractArray`: denoised sample at t=t-1
|
||||||
* `x̂₀::AbstractArray`: denoised sample at t=0
|
* `x̂₀::AbstractArray`: denoised sample at t=0
|
||||||
"""
|
"""
|
||||||
function step(
|
function reverse(
|
||||||
scheduler::Scheduler,
|
scheduler::Scheduler,
|
||||||
xₜ::AbstractArray,
|
xₜ::AbstractArray,
|
||||||
ϵᵧ::AbstractArray,
|
ϵᵧ::AbstractArray,
|
||||||
|
|
|
@ -65,7 +65,7 @@ function DDPM(β::AbstractVector)
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
function add_noise(
|
function forward(
|
||||||
scheduler::DDPM,
|
scheduler::DDPM,
|
||||||
x₀::AbstractArray,
|
x₀::AbstractArray,
|
||||||
ϵ::AbstractArray,
|
ϵ::AbstractArray,
|
||||||
|
@ -86,7 +86,7 @@ function add_noise(
|
||||||
return xₜ
|
return xₜ
|
||||||
end
|
end
|
||||||
|
|
||||||
function step(
|
function reverse(
|
||||||
scheduler::DDPM,
|
scheduler::DDPM,
|
||||||
xₜ::AbstractArray,
|
xₜ::AbstractArray,
|
||||||
ϵᵧ::AbstractArray,
|
ϵᵧ::AbstractArray,
|
||||||
|
|
|
@ -7,7 +7,7 @@ export
|
||||||
DDPM,
|
DDPM,
|
||||||
|
|
||||||
# Scheduler methods
|
# Scheduler methods
|
||||||
add_noise,
|
forward,
|
||||||
step
|
reverse
|
||||||
|
|
||||||
end # module Schedulers
|
end # module Schedulers
|
||||||
|
|
|
@ -15,7 +15,7 @@ using Test
|
||||||
α̅_cosine = cumprod(1 .- β_cosine)
|
α̅_cosine = cumprod(1 .- β_cosine)
|
||||||
α̅_sigmoid = cumprod(1 .- β_sigmoid)
|
α̅_sigmoid = cumprod(1 .- β_sigmoid)
|
||||||
|
|
||||||
# arxiv:2208.11970 (eq. 109)
|
# arxiv:2208.11970 Eq. 109
|
||||||
SNR_linear = α̅_linear ./ (1 .- α̅_linear)
|
SNR_linear = α̅_linear ./ (1 .- α̅_linear)
|
||||||
SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear)
|
SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear)
|
||||||
SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine)
|
SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine)
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import Diffusers: step, add_noise, DDPM, cosine_beta_schedule, rescale_zero_terminal_snr
|
import Diffusers: reverse, forward, DDPM, cosine_beta_schedule
|
||||||
using Statistics
|
using Statistics
|
||||||
using Test
|
using Test
|
||||||
|
|
||||||
@testset "Schedulers tests" begin
|
@testset "Schedulers tests" begin
|
||||||
@testset "check `step` correctness" begin
|
@testset "check `reverse` correctness" begin
|
||||||
T = 10
|
T = 10
|
||||||
batch_size = 8
|
batch_size = 8
|
||||||
size = 128
|
size = 128
|
||||||
|
@ -20,15 +20,15 @@ using Test
|
||||||
for t in 1:T
|
for t in 1:T
|
||||||
t = ones(UInt32, batch_size) .* t
|
t = ones(UInt32, batch_size) .* t
|
||||||
# corrupt x₀ with noise
|
# corrupt x₀ with noise
|
||||||
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t)
|
xₜ = forward(ddpm, x₀, ϵ, t)
|
||||||
# suppose a model predicted ϵ perfectly
|
# suppose a model predicted ϵ perfectly
|
||||||
_, x̂₀ = Diffusers.step(ddpm, xₜ, ϵ, t)
|
_, x̂₀ = reverse(ddpm, xₜ, ϵ, t)
|
||||||
# test that we recover x₀
|
# test that we recover x₀
|
||||||
@test x̂₀ ≈ x₀
|
@test x̂₀ ≈ x₀
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "check `add_noise` terminal SNR" begin
|
@testset "check `forward` terminal SNR" begin
|
||||||
T = 10
|
T = 10
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
size = 2500
|
size = 2500
|
||||||
|
@ -46,7 +46,7 @@ using Test
|
||||||
|
|
||||||
t = ones(UInt32, batch_size) .* T
|
t = ones(UInt32, batch_size) .* T
|
||||||
# corrupt x₀ with noise
|
# corrupt x₀ with noise
|
||||||
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t)
|
xₜ = forward(ddpm, x₀, ϵ, t)
|
||||||
|
|
||||||
@test std(xₜ) ≈ 1.0 atol = 1.0f-3
|
@test std(xₜ) ≈ 1.0 atol = 1.0f-3
|
||||||
@test mean(xₜ) ≈ 0.0 atol = 1.0f-2
|
@test mean(xₜ) ≈ 0.0 atol = 1.0f-2
|
||||||
|
|
Loading…
Reference in a new issue