From a6fe13ba5e4cfd36b77f8080d077bea0586b9451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Mon, 14 Aug 2023 16:29:39 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20(Schedulers)=20make=20tests=20a?= =?UTF-8?q?=20bit=20more=20readable?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Diffusers.jl | 8 ++++---- test/Schedulers.jl | 12 +++--------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/Diffusers.jl b/src/Diffusers.jl index 13961a0..27c933f 100644 --- a/src/Diffusers.jl +++ b/src/Diffusers.jl @@ -29,15 +29,15 @@ import .Schedulers: DDPM, # Scheduler methods - add_noise, - step + forward, + reverse export # Scheduler DDPM, # Scheduler methods - add_noise, - step + forward, + reverse end # module Diffusers diff --git a/test/Schedulers.jl b/test/Schedulers.jl index fe28aae..3d8b24c 100644 --- a/test/Schedulers.jl +++ b/test/Schedulers.jl @@ -1,5 +1,5 @@ import Diffusers: reverse, forward, DDPM, cosine_beta_schedule -using Statistics +import Statistics: mean, std using Test @testset "Schedulers tests" begin @@ -9,9 +9,7 @@ using Test size = 128 # create a DDPM with a cosine beta schedule - ddpm = Diffusers.DDPM( - Diffusers.cosine_beta_schedule(T), - ) + ddpm = DDPM(cosine_beta_schedule(T)) # create some dummy data x₀ = ones(Float32, size, size, batch_size) @@ -34,11 +32,7 @@ using Test size = 2500 # create a DDPM with a terminal SNR cosine beta schedule - ddpm = Diffusers.DDPM( - Diffusers.rescale_zero_terminal_snr( - Diffusers.cosine_beta_schedule(T), - ), - ) + ddpm = DDPM(cosine_beta_schedule(T)) # create some dummy data x₀ = ones(Float32, size, size, batch_size)