mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-18 18:45:28 +00:00
🎨 (DDPM) simplify constructor
This commit is contained in:
parent
4b344dcc5d
commit
b6c8309733
|
@ -38,7 +38,6 @@ scatter(dataset[1, :], dataset[2, :],
|
||||||
|
|
||||||
num_timesteps = 100
|
num_timesteps = 100
|
||||||
scheduler = DDPM(
|
scheduler = DDPM(
|
||||||
Vector{Float32},
|
|
||||||
linear_beta_schedule(num_timesteps)
|
linear_beta_schedule(num_timesteps)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ struct DDPM{V<:AbstractVector} <: Scheduler
|
||||||
⎷β̅₋₁::V # square root of β̅₋₁
|
⎷β̅₋₁::V # square root of β̅₋₁
|
||||||
end
|
end
|
||||||
|
|
||||||
function DDPM(V::DataType, β::AbstractVector)
|
function DDPM(β::AbstractVector)
|
||||||
T = length(β)
|
T = length(β)
|
||||||
|
|
||||||
α = 1 .- β
|
α = 1 .- β
|
||||||
|
@ -48,7 +48,7 @@ function DDPM(V::DataType, β::AbstractVector)
|
||||||
⎷α̅₋₁ = sqrt.(α̅₋₁)
|
⎷α̅₋₁ = sqrt.(α̅₋₁)
|
||||||
⎷β̅₋₁ = sqrt.(β̅₋₁)
|
⎷β̅₋₁ = sqrt.(β̅₋₁)
|
||||||
|
|
||||||
DDPM{V}(
|
DDPM{typeof(β)}(
|
||||||
T,
|
T,
|
||||||
β,
|
β,
|
||||||
α,
|
α,
|
||||||
|
|
|
@ -10,7 +10,6 @@ using Test
|
||||||
|
|
||||||
# create a DDPM with a cosine beta schedule
|
# create a DDPM with a cosine beta schedule
|
||||||
ddpm = Diffusers.DDPM(
|
ddpm = Diffusers.DDPM(
|
||||||
Vector{Float32},
|
|
||||||
Diffusers.cosine_beta_schedule(T),
|
Diffusers.cosine_beta_schedule(T),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -36,7 +35,6 @@ using Test
|
||||||
|
|
||||||
# create a DDPM with a terminal SNR cosine beta schedule
|
# create a DDPM with a terminal SNR cosine beta schedule
|
||||||
ddpm = Diffusers.DDPM(
|
ddpm = Diffusers.DDPM(
|
||||||
Vector{Float32},
|
|
||||||
Diffusers.rescale_zero_terminal_snr(
|
Diffusers.rescale_zero_terminal_snr(
|
||||||
Diffusers.cosine_beta_schedule(T),
|
Diffusers.cosine_beta_schedule(T),
|
||||||
),
|
),
|
||||||
|
|
Loading…
Reference in a new issue