diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py b/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py index 23bd7f4..3eb00cf 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py @@ -1,7 +1,7 @@ from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler +from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler __all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"] diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py index 726bdcb..f6f2cba 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py @@ -1,4 +1,4 @@ -from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor, Generator +from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index 3b79348..52e706c 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -1,8 +1,10 @@ -from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler -import numpy as np -from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype, Generator from collections import deque +import numpy as np +from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor + +from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler + class DPMSolver(Scheduler): """Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095 diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/euler.py b/src/refiners/foundationals/latent_diffusion/schedulers/euler.py index bb340d2..2a69c8c 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/euler.py @@ -1,7 +1,8 @@ -from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler -from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator -import torch import numpy as np +import torch +from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor + +from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler class EulerScheduler(Scheduler): diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index f570ad8..f64a4cc 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from enum import Enum -from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log, Generator from typing import TypeVar +from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt + T = TypeVar("T", bound="Scheduler") diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index 5f1d9e0..4b5841f 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -2,7 +2,7 @@ from typing import cast from warnings import warn import pytest -from torch import Tensor, allclose, device as Device, equal, randn, isclose +from torch import Tensor, allclose, device as Device, equal, isclose, randn from refiners.fluxion import manual_seed from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler