diff --git a/src/DDPM.jl b/src/DDPM.jl index 5d09cae..826b6a8 100644 --- a/src/DDPM.jl +++ b/src/DDPM.jl @@ -3,7 +3,7 @@ include("Schedulers.jl") """ Denoising Diffusion Probabilistic Models (DDPM) scheduler. -https://arxiv.org/abs/2006.11239 +cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006). """ struct DDPM{V<:AbstractVector} <: Scheduler # number of diffusion steps used to train the model. @@ -41,14 +41,15 @@ end """ Remove noise from model output using the backward diffusion process. -Args: - scheduler (`DDPM`): scheduler object. - sample (`AbstractArray`): sample to remove noise from, i.e. model_input. - model_output (`AbstractArray`): predicted noise from the model. - timesteps (`AbstractArray`): timesteps to remove noise from. +## Input + * scheduler (`DDPM`): scheduler to use + * sample (`AbstractArray`): sample to remove noise from, i.e. model_input + * model_output (`AbstractArray`): predicted noise from the model + * timesteps (`AbstractArray`): timesteps to remove noise from -Returns: - `AbstractArray`: denoised model output at the given timestep. +## Output + * pred_prev_sample (`AbstractArray`): denoised sample at t=t-1 + * x_0_pred (`AbstractArray`): denoised sample at t=0 """ function step( scheduler::DDPM, diff --git a/src/Schedulers.jl b/src/Schedulers.jl index 78d325b..6c7a5f2 100644 --- a/src/Schedulers.jl +++ b/src/Schedulers.jl @@ -4,13 +4,13 @@ abstract type Scheduler end Add noise to clean data using the forward diffusion process. ## Input - * scheduler (`Scheduler`): scheduler to use. - * clean_data (`AbstractArray`): clean data to add noise to. - * noise (`AbstractArray`): noise to add to clean data. - * timesteps (`AbstractArray`): timesteps used to weight the noise. + * scheduler (`Scheduler`): scheduler to use + * clean_data (`AbstractArray`): clean data to add noise to + * noise (`AbstractArray`): noise to add to clean data + * timesteps (`AbstractArray`): timesteps used to weight the noise ## Output - * noisy_data (`AbstractArray`): noisy data at the given timesteps. + * noisy_data (`AbstractArray`): noisy data at the given timesteps """ function add_noise( scheduler::Scheduler,