diff --git a/src/DDPM.jl b/src/DDPM.jl index ebb3fd3..322af2f 100644 --- a/src/DDPM.jl +++ b/src/DDPM.jl @@ -19,7 +19,6 @@ struct DDPM{V<:AbstractVector} <: Scheduler sqrt_one_minus_α_cumprods::V end - function DDPM(V::DataType, βs::AbstractVector) αs = 1 .- βs α_cumprods = cumprod(αs) @@ -39,6 +38,50 @@ function DDPM(V::DataType, βs::AbstractVector) ) end -function DDPM(V::DataType, beta_scheduler) - DDPM(V, beta_scheduler) +function step( + scheduler::DDPM, + sample::AbstractArray, + model_output::AbstractArray, + timesteps::AbstractArray, +) + """ + Remove noise from model output using the backward diffusion process. + + Args: + scheduler (`Scheduler`): scheduler object. + sample (`AbstractArray`): sample to remove noise from, i.e. model_input. + model_output (`AbstractArray`): predicted noise from the model. + timesteps (`AbstractArray`): timesteps to remove noise from. + + Returns: + `AbstractArray`: denoised model output at the given timestep. + """ + # 1. compute alphas, betas + α_cumprod_t = scheduler.α_cumprods[timesteps] + α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1] + β_cumprod_t = 1 .- α_cumprod_t + β_cumprod_t_prev = 1 .- α_cumprod_t_prev + current_α_t = α_cumprod_t ./ α_cumprod_t_prev + current_β_t = 1 .- current_α_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + # epsilon prediction type + # print shapes of thingies + x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)' + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t + current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample + + # 6. Add noise + variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output)) + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, x_0_pred end diff --git a/src/Schedulers.jl b/src/Schedulers.jl index d3ecd99..1ba7869 100644 --- a/src/Schedulers.jl +++ b/src/Schedulers.jl @@ -25,51 +25,3 @@ function add_noise( return noisy_data end - -function step( - scheduler::Scheduler, - sample::AbstractArray, - model_output::AbstractArray, - timesteps::AbstractArray, -) - """ - Remove noise from model output using the backward diffusion process. - - Args: - scheduler (`Scheduler`): scheduler object. - sample (`AbstractArray`): sample to remove noise from, i.e. model_input. - model_output (`AbstractArray`): predicted noise from the model. - timesteps (`AbstractArray`): timesteps to remove noise from. - - Returns: - `AbstractArray`: denoised model output at the given timestep. - """ - # 1. compute alphas, betas - α_cumprod_t = scheduler.α_cumprods[timesteps] - α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1] - β_cumprod_t = 1 .- α_cumprod_t - β_cumprod_t_prev = 1 .- α_cumprod_t_prev - current_α_t = α_cumprod_t ./ α_cumprod_t_prev - current_β_t = 1 .- current_α_t - - # 2. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - # epsilon prediction type - # print shapes of thingies - x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)' - - # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t - current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t - - # 5. Compute predicted previous sample µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample - - # 6. Add noise - variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output)) - pred_prev_sample = pred_prev_sample + variance - - return pred_prev_sample -end