🎨 (DDPM) simplify constructor

This commit is contained in:
Laureηt 2023-08-11 20:07:50 +02:00
parent 4b344dcc5d
commit b6c8309733
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 2 additions and 5 deletions

View file

@ -38,7 +38,6 @@ scatter(dataset[1, :], dataset[2, :],
num_timesteps = 100 num_timesteps = 100
scheduler = DDPM( scheduler = DDPM(
Vector{Float32},
linear_beta_schedule(num_timesteps) linear_beta_schedule(num_timesteps)
); );

View file

@ -28,7 +28,7 @@ struct DDPM{V<:AbstractVector} <: Scheduler
⎷β̅₋₁::V # square root of β̅₋₁ ⎷β̅₋₁::V # square root of β̅₋₁
end end
function DDPM(V::DataType, β::AbstractVector) function DDPM(β::AbstractVector)
T = length(β) T = length(β)
α = 1 .- β α = 1 .- β
@ -48,7 +48,7 @@ function DDPM(V::DataType, β::AbstractVector)
⎷α̅₋₁ = sqrt.(α̅₋₁) ⎷α̅₋₁ = sqrt.(α̅₋₁)
⎷β̅₋₁ = sqrt.(β̅₋₁) ⎷β̅₋₁ = sqrt.(β̅₋₁)
DDPM{V}( DDPM{typeof(β)}(
T, T,
β, β,
α, α,

View file

@ -10,7 +10,6 @@ using Test
# create a DDPM with a cosine beta schedule # create a DDPM with a cosine beta schedule
ddpm = Diffusers.DDPM( ddpm = Diffusers.DDPM(
Vector{Float32},
Diffusers.cosine_beta_schedule(T), Diffusers.cosine_beta_schedule(T),
) )
@ -36,7 +35,6 @@ using Test
# create a DDPM with a terminal SNR cosine beta schedule # create a DDPM with a terminal SNR cosine beta schedule
ddpm = Diffusers.DDPM( ddpm = Diffusers.DDPM(
Vector{Float32},
Diffusers.rescale_zero_terminal_snr( Diffusers.rescale_zero_terminal_snr(
Diffusers.cosine_beta_schedule(T), Diffusers.cosine_beta_schedule(T),
), ),