From cc0326450b93dcf1b56717a58bf027a2861ae802 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Mon, 14 Aug 2023 16:26:15 +0200 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20(Schedulers)=20rename=20me?= =?UTF-8?q?thods=20names=20add=5Fnoise=20->=20forward=20step=20->=20revers?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Schedulers/Abstract.jl | 4 ++-- src/Schedulers/DDPM.jl | 4 ++-- src/Schedulers/Schedulers.jl | 4 ++-- test/BetaSchedules.jl | 2 +- test/Schedulers.jl | 12 ++++++------ 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/Schedulers/Abstract.jl b/src/Schedulers/Abstract.jl index 5625d71..a419fec 100644 --- a/src/Schedulers/Abstract.jl +++ b/src/Schedulers/Abstract.jl @@ -16,7 +16,7 @@ Add noise to clean data using the forward diffusion process. ## Output * `xₜ::AbstractArray`: noisy data at the given timesteps """ -function add_noise( +function forward( scheduler::Scheduler, x₀::AbstractArray, ϵ::AbstractArray, @@ -36,7 +36,7 @@ Remove noise from model output using the backward diffusion process. * `xₜ₋₁::AbstractArray`: denoised sample at t=t-1 * `x̂₀::AbstractArray`: denoised sample at t=0 """ -function step( +function reverse( scheduler::Scheduler, xₜ::AbstractArray, ϵᵧ::AbstractArray, diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index c8afc5c..59dcd22 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -65,7 +65,7 @@ function DDPM(β::AbstractVector) ) end -function add_noise( +function forward( scheduler::DDPM, x₀::AbstractArray, ϵ::AbstractArray, @@ -86,7 +86,7 @@ function add_noise( return xₜ end -function step( +function reverse( scheduler::DDPM, xₜ::AbstractArray, ϵᵧ::AbstractArray, diff --git a/src/Schedulers/Schedulers.jl b/src/Schedulers/Schedulers.jl index cd4fc88..176b366 100644 --- a/src/Schedulers/Schedulers.jl +++ b/src/Schedulers/Schedulers.jl @@ -7,7 +7,7 @@ export DDPM, # Scheduler methods - add_noise, - step + forward, + reverse end # module Schedulers diff --git a/test/BetaSchedules.jl b/test/BetaSchedules.jl index e87de92..a186e05 100644 --- a/test/BetaSchedules.jl +++ b/test/BetaSchedules.jl @@ -15,7 +15,7 @@ using Test α̅_cosine = cumprod(1 .- β_cosine) α̅_sigmoid = cumprod(1 .- β_sigmoid) - # arxiv:2208.11970 (eq. 109) + # arxiv:2208.11970 Eq. 109 SNR_linear = α̅_linear ./ (1 .- α̅_linear) SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear) SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine) diff --git a/test/Schedulers.jl b/test/Schedulers.jl index 20d6653..fe28aae 100644 --- a/test/Schedulers.jl +++ b/test/Schedulers.jl @@ -1,9 +1,9 @@ -import Diffusers: step, add_noise, DDPM, cosine_beta_schedule, rescale_zero_terminal_snr +import Diffusers: reverse, forward, DDPM, cosine_beta_schedule using Statistics using Test @testset "Schedulers tests" begin - @testset "check `step` correctness" begin + @testset "check `reverse` correctness" begin T = 10 batch_size = 8 size = 128 @@ -20,15 +20,15 @@ using Test for t in 1:T t = ones(UInt32, batch_size) .* t # corrupt x₀ with noise - xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t) + xₜ = forward(ddpm, x₀, ϵ, t) # suppose a model predicted ϵ perfectly - _, x̂₀ = Diffusers.step(ddpm, xₜ, ϵ, t) + _, x̂₀ = reverse(ddpm, xₜ, ϵ, t) # test that we recover x₀ @test x̂₀ ≈ x₀ end end - @testset "check `add_noise` terminal SNR" begin + @testset "check `forward` terminal SNR" begin T = 10 batch_size = 1 size = 2500 @@ -46,7 +46,7 @@ using Test t = ones(UInt32, batch_size) .* T # corrupt x₀ with noise - xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t) + xₜ = forward(ddpm, x₀, ϵ, t) @test std(xₜ) ≈ 1.0 atol = 1.0f-3 @test mean(xₜ) ≈ 0.0 atol = 1.0f-2