From b6c83097338ff530ea8affa9e21ad1c921aaed5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Fri, 11 Aug 2023 20:07:50 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20(DDPM)=20simplify=20constructor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/swissroll.jl | 1 - src/Schedulers/DDPM.jl | 4 ++-- test/Schedulers.jl | 2 -- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/swissroll.jl b/examples/swissroll.jl index bc7eca1..9c1f28d 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -38,7 +38,6 @@ scatter(dataset[1, :], dataset[2, :], num_timesteps = 100 scheduler = DDPM( - Vector{Float32}, linear_beta_schedule(num_timesteps) ); diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 846e6f1..c8afc5c 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -28,7 +28,7 @@ struct DDPM{V<:AbstractVector} <: Scheduler ⎷β̅₋₁::V # square root of β̅₋₁ end -function DDPM(V::DataType, β::AbstractVector) +function DDPM(β::AbstractVector) T = length(β) α = 1 .- β @@ -48,7 +48,7 @@ function DDPM(V::DataType, β::AbstractVector) ⎷α̅₋₁ = sqrt.(α̅₋₁) ⎷β̅₋₁ = sqrt.(β̅₋₁) - DDPM{V}( + DDPM{typeof(β)}( T, β, α, diff --git a/test/Schedulers.jl b/test/Schedulers.jl index d6ed877..20d6653 100644 --- a/test/Schedulers.jl +++ b/test/Schedulers.jl @@ -10,7 +10,6 @@ using Test # create a DDPM with a cosine beta schedule ddpm = Diffusers.DDPM( - Vector{Float32}, Diffusers.cosine_beta_schedule(T), ) @@ -36,7 +35,6 @@ using Test # create a DDPM with a terminal SNR cosine beta schedule ddpm = Diffusers.DDPM( - Vector{Float32}, Diffusers.rescale_zero_terminal_snr( Diffusers.cosine_beta_schedule(T), ),