🚚 move DDPM's step from Schedulers.jl to DDPM.jl

This commit is contained in:
Laureηt 2023-07-25 20:08:35 +02:00
parent cfd090b6e2
commit b4ed6e3c99
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
2 changed files with 46 additions and 51 deletions

View file

@ -19,7 +19,6 @@ struct DDPM{V<:AbstractVector} <: Scheduler
sqrt_one_minus_α_cumprods::V
end
function DDPM(V::DataType, βs::AbstractVector)
αs = 1 .- βs
α_cumprods = cumprod(αs)
@ -39,6 +38,50 @@ function DDPM(V::DataType, βs::AbstractVector)
)
end
function DDPM(V::DataType, beta_scheduler)
DDPM(V, beta_scheduler)
function step(
scheduler::DDPM,
sample::AbstractArray,
model_output::AbstractArray,
timesteps::AbstractArray,
)
"""
Remove noise from model output using the backward diffusion process.
Args:
scheduler (`Scheduler`): scheduler object.
sample (`AbstractArray`): sample to remove noise from, i.e. model_input.
model_output (`AbstractArray`): predicted noise from the model.
timesteps (`AbstractArray`): timesteps to remove noise from.
Returns:
`AbstractArray`: denoised model output at the given timestep.
"""
# 1. compute alphas, betas
α_cumprod_t = scheduler.α_cumprods[timesteps]
α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1]
β_cumprod_t = 1 .- α_cumprod_t
β_cumprod_t_prev = 1 .- α_cumprod_t_prev
current_α_t = α_cumprod_t ./ α_cumprod_t_prev
current_β_t = 1 .- current_α_t
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
# epsilon prediction type
# print shapes of thingies
x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)'
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t
current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
# 6. Add noise
variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output))
pred_prev_sample = pred_prev_sample + variance
return pred_prev_sample, x_0_pred
end

View file

@ -25,51 +25,3 @@ function add_noise(
return noisy_data
end
function step(
scheduler::Scheduler,
sample::AbstractArray,
model_output::AbstractArray,
timesteps::AbstractArray,
)
"""
Remove noise from model output using the backward diffusion process.
Args:
scheduler (`Scheduler`): scheduler object.
sample (`AbstractArray`): sample to remove noise from, i.e. model_input.
model_output (`AbstractArray`): predicted noise from the model.
timesteps (`AbstractArray`): timesteps to remove noise from.
Returns:
`AbstractArray`: denoised model output at the given timestep.
"""
# 1. compute alphas, betas
α_cumprod_t = scheduler.α_cumprods[timesteps]
α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1]
β_cumprod_t = 1 .- α_cumprod_t
β_cumprod_t_prev = 1 .- α_cumprod_t_prev
current_α_t = α_cumprod_t ./ α_cumprod_t_prev
current_β_t = 1 .- current_α_t
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
# epsilon prediction type
# print shapes of thingies
x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)'
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t
current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
# 6. Add noise
variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output))
pred_prev_sample = pred_prev_sample + variance
return pred_prev_sample
end