From 972013637d7fe3787429a3f4d9fa39253128667a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Thu, 27 Jul 2023 19:46:40 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20update=20some=20docstrings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/BetaSchedulers.jl | 95 +++++++++++++++++++++++++++++-------------- src/DDPM.jl | 24 +++++------ src/Schedulers.jl | 2 +- 3 files changed, 78 insertions(+), 43 deletions(-) diff --git a/src/BetaSchedulers.jl b/src/BetaSchedulers.jl index 7571c3e..637b01a 100644 --- a/src/BetaSchedulers.jl +++ b/src/BetaSchedulers.jl @@ -3,47 +3,82 @@ import NNlib: sigmoid """ Linear beta schedule. -cf. https://arxiv.org/abs/2006.11239 +cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) + +## Input + * T (`Int`): number of timesteps + * β_1 (`Real := 0.0001f0`): initial value of β + * β_T (`Real := 0.02f0`): final value of β + +## Output + * βs (`Vector{Real}`): β_t values at each timestep t """ -function linear_beta_schedule(num_timesteps::Int, β_start=0.0001f0, β_end=0.02f0) - range(β_start, β_end, length=num_timesteps) +function linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0) + return range(start=β_1, stop=β_T, length=T) end """ Scaled linear beta schedule. -(very specific to latent diffusion models) -cf. https://arxiv.org/abs/2006.11239 +cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) + +## Input + * T (`Int`): number of timesteps + * β_1 (`Real := 0.0001f0`): initial value of β + * β_T (`Real := 0.02f0`): final value of β + +## Output + * βs (`Vector{Real}`): β_t values at each timestep t """ -function scaled_linear_beta_schedule(num_timesteps::Int, β_start=0.0001f0, β_end=0.02f0) - range(β_start^0.5, β_end^0.5, length=num_timesteps) .^ 2 -end - -""" -Cosine beta schedule. - -cf. https://arxiv.org/abs/2102.09672 -""" -function cosine_beta_schedule(num_timesteps::Int, max_beta=0.999f0, ϵ=1e-3f0) - α_bar(t) = cos((t + ϵ) / (1 + ϵ) * π / 2)^2 - - βs = Float32[] - for i in 1:num_timesteps - t1 = (i - 1) / num_timesteps - t2 = i / num_timesteps - push!(βs, min(1 - α_bar(t2) / α_bar(t1), max_beta)) - end - - return βs +function scaled_linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0) + return range(start=β_1^0.5, stop=β_T^0.5, length=T) .^ 2 end """ Sigmoid beta schedule. -cf. https://arxiv.org/abs/2203.02923 -and https://github.com/MinkaiXu/GeoDiff/blob/main/models/epsnet/diffusion.py#L34 +cf. [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923) +and [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57) + +## Input + * T (`Int`): number of timesteps + * β_1 (`Real := 0.0001f0`): initial value of β + * β_T (`Real := 0.02f0`): final value of β + +## Output + * βs (`Vector{Real}`): β_t values at each timestep t """ -function sigmoid_beta_schedule(num_timesteps::Int, β_start=0.0001f0, β_end=0.02f0) - x = range(-6, 6, length=num_timesteps) - sigmoid(x) * (β_end - β_start) + β_start +function sigmoid_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0) + x = range(start=-6, stop=6, length=T) + return sigmoid(x) * (β_T - β_1) + β_1 +end + +""" +Cosine beta schedule. + +cf. [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672) + +## Input + * T (`Int`): number of timesteps + * β_max (`Real := 0.999f0`): maximum value of β + * ϵ (`Real := 1e-3f0`): small value used to avoid division by zero + +## Output + * βs (`Vector{Real}`): β_t values at each timestep t +""" +function cosine_beta_schedule(T::Int, β_max::Real=0.999f0, ϵ::Real=1e-3f0) + α_bar(t) = cos((t + ϵ) / (1 + ϵ) * π / 2)^2 + + βs = Float32[] + for t in 1:T + t1 = (t - 1) / T + t2 = t / T + + β_t = 1 - α_bar(t2) / α_bar(t1) + β_t = min(β_max, β_t) + + push!(βs, β_t) + end + + return βs end diff --git a/src/DDPM.jl b/src/DDPM.jl index 322af2f..5d09cae 100644 --- a/src/DDPM.jl +++ b/src/DDPM.jl @@ -38,24 +38,24 @@ function DDPM(V::DataType, βs::AbstractVector) ) end +""" +Remove noise from model output using the backward diffusion process. + +Args: + scheduler (`DDPM`): scheduler object. + sample (`AbstractArray`): sample to remove noise from, i.e. model_input. + model_output (`AbstractArray`): predicted noise from the model. + timesteps (`AbstractArray`): timesteps to remove noise from. + +Returns: + `AbstractArray`: denoised model output at the given timestep. +""" function step( scheduler::DDPM, sample::AbstractArray, model_output::AbstractArray, timesteps::AbstractArray, ) - """ - Remove noise from model output using the backward diffusion process. - - Args: - scheduler (`Scheduler`): scheduler object. - sample (`AbstractArray`): sample to remove noise from, i.e. model_input. - model_output (`AbstractArray`): predicted noise from the model. - timesteps (`AbstractArray`): timesteps to remove noise from. - - Returns: - `AbstractArray`: denoised model output at the given timestep. - """ # 1. compute alphas, betas α_cumprod_t = scheduler.α_cumprods[timesteps] α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1] diff --git a/src/Schedulers.jl b/src/Schedulers.jl index 35ef409..78d325b 100644 --- a/src/Schedulers.jl +++ b/src/Schedulers.jl @@ -4,7 +4,7 @@ abstract type Scheduler end Add noise to clean data using the forward diffusion process. ## Input - * scheduler (`Scheduler`): scheduler object. + * scheduler (`Scheduler`): scheduler to use. * clean_data (`AbstractArray`): clean data to add noise to. * noise (`AbstractArray`): noise to add to clean data. * timesteps (`AbstractArray`): timesteps used to weight the noise.