From ec302abc830dab8bbbda5da9c4525e31ada7b818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Fri, 6 Oct 2023 13:43:30 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=9A=20(Schedulers)=20move=20abstract.j?= =?UTF-8?q?l=20includes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Schedulers/Abstract.jl | 8 ++++---- src/Schedulers/DDIM.jl | 2 -- src/Schedulers/DDPM.jl | 2 -- src/Schedulers/Schedulers.jl | 1 + 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/Schedulers/Abstract.jl b/src/Schedulers/Abstract.jl index fe4f51e..2a2d7f3 100644 --- a/src/Schedulers/Abstract.jl +++ b/src/Schedulers/Abstract.jl @@ -1,11 +1,11 @@ +@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED +@enum PredictionType EPSILON SAMPLE VELOCITY + """ Abstract type for schedulers. - """ abstract type Scheduler end -@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED -@enum PredictionType EPSILON SAMPLE VELOCITY """ Add noise to clean data using the forward diffusion process. @@ -59,7 +59,7 @@ Compute the velocity of the diffusion process. * `vₜ::AbstractArray`: velocity at the given timesteps ## References - * [[2202.00512] Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512) (Ann. D) + [salimans2022progressive; Ann. D](@cite) """ function get_velocity( scheduler::Scheduler, diff --git a/src/Schedulers/DDIM.jl b/src/Schedulers/DDIM.jl index b2bdb3f..76521a2 100644 --- a/src/Schedulers/DDIM.jl +++ b/src/Schedulers/DDIM.jl @@ -1,5 +1,3 @@ -include("Abstract.jl") - using ShiftedArrays function _extract( diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 94d8c36..343b910 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -1,5 +1,3 @@ -include("Abstract.jl") - using ShiftedArrays function _extract( diff --git a/src/Schedulers/Schedulers.jl b/src/Schedulers/Schedulers.jl index 4b6c3b6..19d8ed4 100644 --- a/src/Schedulers/Schedulers.jl +++ b/src/Schedulers/Schedulers.jl @@ -1,5 +1,6 @@ module Schedulers +include("Abstract.jl") include("DDPM.jl") include("DDIM.jl")