🎨 (DDPM) simplify constructor

This commit is contained in:
Laureηt 2023-08-11 20:07:50 +02:00
parent 4b344dcc5d
commit b6c8309733
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 2 additions and 5 deletions

View file

@ -38,7 +38,6 @@ scatter(dataset[1, :], dataset[2, :],
num_timesteps = 100
scheduler = DDPM(
Vector{Float32},
linear_beta_schedule(num_timesteps)
);

View file

@ -28,7 +28,7 @@ struct DDPM{V<:AbstractVector} <: Scheduler
⎷β̅₋₁::V # square root of β̅₋₁
end
function DDPM(V::DataType, β::AbstractVector)
function DDPM(β::AbstractVector)
T = length(β)
α = 1 .- β
@ -48,7 +48,7 @@ function DDPM(V::DataType, β::AbstractVector)
⎷α̅₋₁ = sqrt.(α̅₋₁)
⎷β̅₋₁ = sqrt.(β̅₋₁)
DDPM{V}(
DDPM{typeof(β)}(
T,
β,
α,

View file

@ -10,7 +10,6 @@ using Test
# create a DDPM with a cosine beta schedule
ddpm = Diffusers.DDPM(
Vector{Float32},
Diffusers.cosine_beta_schedule(T),
)
@ -36,7 +35,6 @@ using Test
# create a DDPM with a terminal SNR cosine beta schedule
ddpm = Diffusers.DDPM(
Vector{Float32},
Diffusers.rescale_zero_terminal_snr(
Diffusers.cosine_beta_schedule(T),
),