From 184b9665ecf093d00127ef57702c706cc128e101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Mon, 14 Aug 2023 17:19:43 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20(Schedulers)=20add=20get=5Fveloc?= =?UTF-8?q?ity=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/Schedulers.jl | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/test/Schedulers.jl b/test/Schedulers.jl index 3d8b24c..ea15b8f 100644 --- a/test/Schedulers.jl +++ b/test/Schedulers.jl @@ -1,8 +1,9 @@ -import Diffusers: reverse, forward, DDPM, cosine_beta_schedule +import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY import Statistics: mean, std using Test @testset "Schedulers tests" begin + @testset "check `reverse` correctness" begin T = 10 batch_size = 8 @@ -15,15 +16,36 @@ using Test x₀ = ones(Float32, size, size, batch_size) ϵ = randn(Float32, size, size, batch_size) - for t in 1:T - t = ones(UInt32, batch_size) .* t - # corrupt x₀ with noise - xₜ = forward(ddpm, x₀, ϵ, t) - # suppose a model predicted ϵ perfectly - _, x̂₀ = reverse(ddpm, xₜ, ϵ, t) - # test that we recover x₀ - @test x̂₀ ≈ x₀ + @testset "PredictionType == EPSILON" begin + for t in 1:T + t = ones(UInt32, batch_size) .* t + # get xₜ from forward diffusion process + xₜ = forward(ddpm, x₀, ϵ, t) + # suppose a model predicted ϵ perfectly + ϵᵧ = ϵ + # use reverse diffusion process to retreive x̂₀ + _, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, t) + # test that we recover x₀ + @test x̂₀ ≈ x₀ + end end + + @testset "PredictionType == VELOCITY" begin + for t in 1:T + t = ones(UInt32, batch_size) .* t + # get xₜ from forward diffusion process + xₜ = forward(ddpm, x₀, ϵ, t) + # compute vₜ to train model + vₜ = get_velocity(ddpm, x₀, ϵ, t) + # suppose a model predicted vₜ perfectly + vᵧ = vₜ + # use reverse diffusion process to retreive x̂₀ + _, x̂₀ = Diffusers.reverse(ddpm, xₜ, vₜ, t, VELOCITY) + # test that we recover x₀ + @test x̂₀ ≈ x₀ + end + end + end @testset "check `forward` terminal SNR" begin @@ -45,4 +67,5 @@ using Test @test std(xₜ) ≈ 1.0 atol = 1.0f-3 @test mean(xₜ) ≈ 0.0 atol = 1.0f-2 end + end