From 73f6ccfc98fcd27baefc9dfec3bfd84c1ccd46b3 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Wed, 31 Jan 2024 14:07:34 +0000 Subject: [PATCH] make Scheduler a fl.Module + Change name Scheduler -> Solver --- README.md | 2 +- .../convert_diffusers_controlnet.py | 4 +- .../latent_diffusion/__init__.py | 4 +- .../foundationals/latent_diffusion/model.py | 24 +-- .../latent_diffusion/multi_diffusion.py | 2 +- .../foundationals/latent_diffusion/restart.py | 34 ++-- .../latent_diffusion/schedulers/__init__.py | 7 - .../self_attention_guidance.py | 8 +- .../latent_diffusion/solvers/__init__.py | 7 + .../{schedulers => solvers}/ddim.py | 9 +- .../{schedulers => solvers}/ddpm.py | 4 +- .../dpm_solver.py => solvers/dpm.py} | 5 +- .../{schedulers => solvers}/euler.py | 8 +- .../scheduler.py => solvers/solver.py} | 150 ++++++++++-------- .../stable_diffusion_1/model.py | 22 +-- .../stable_diffusion_xl/model.py | 14 +- .../training_utils/latent_diffusion.py | 11 +- tests/e2e/test_diffusion.py | 38 ++--- .../{test_schedulers.py => test_solvers.py} | 4 +- 19 files changed, 184 insertions(+), 173 deletions(-) delete mode 100644 src/refiners/foundationals/latent_diffusion/schedulers/__init__.py create mode 100644 src/refiners/foundationals/latent_diffusion/solvers/__init__.py rename src/refiners/foundationals/latent_diffusion/{schedulers => solvers}/ddim.py (88%) rename src/refiners/foundationals/latent_diffusion/{schedulers => solvers}/ddpm.py (91%) rename src/refiners/foundationals/latent_diffusion/{schedulers/dpm_solver.py => solvers/dpm.py} (97%) rename src/refiners/foundationals/latent_diffusion/{schedulers => solvers}/euler.py (90%) rename src/refiners/foundationals/latent_diffusion/{schedulers/scheduler.py => solvers/solver.py} (79%) rename tests/foundationals/latent_diffusion/{test_schedulers.py => test_solvers.py} (96%) diff --git a/README.md b/README.md index 98ed4fe..7b0ba89 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ ______________________________________________________________________ ## Latest News 🔥 -- Added [Euler's method](https://arxiv.org/abs/2206.00364) to schedulers (contributed by [@israfelsr](https://github.com/israfelsr)) +- Added [Euler's method](https://arxiv.org/abs/2206.00364) to solvers (contributed by [@israfelsr](https://github.com/israfelsr)) - Added [DINOv2](https://github.com/facebookresearch/dinov2) for high-performance visual features (contributed by [@Laurent2916](https://github.com/Laurent2916)) - Added [FreeU](https://github.com/ChenyangSi/FreeU) for improved quality at no cost (contributed by [@isamu-isozaki](https://github.com/isamu-isozaki)) - Added [Restart Sampling](https://github.com/Newbeeer/diffusion_restart_sampling) for improved image generation ([example](https://github.com/Newbeeer/diffusion_restart_sampling/issues/4)) diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py index 5185193..e603bf8 100644 --- a/scripts/conversion/convert_diffusers_controlnet.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -37,8 +37,8 @@ def convert(args: Args) -> dict[str, torch.Tensor]: clip_text_embedding = torch.rand(1, 77, 768) unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) - scheduler = DPMSolver(num_inference_steps=10) - timestep = scheduler.timesteps[0].unsqueeze(dim=0) + solver = DPMSolver(num_inference_steps=10) + timestep = solver.timesteps[0].unsqueeze(dim=0) unet.set_timestep(timestep=timestep.unsqueeze(dim=0)) x = torch.randn(1, 4, 64, 64) diff 
--git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py index 68b6517..ce47a46 100644 --- a/src/refiners/foundationals/latent_diffusion/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/__init__.py @@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import ( LatentDiffusionAutoencoder, ) from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter -from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler +from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver from refiners.foundationals.latent_diffusion.stable_diffusion_1 import ( SD1ControlnetAdapter, SD1IPAdapter, @@ -33,7 +33,7 @@ __all__ = [ "SDXLIPAdapter", "SDXLT2IAdapter", "DPMSolver", - "Scheduler", + "Solver", "CLIPTextEncoderL", "LatentDiffusionAutoencoder", "SDFreeUAdapter", diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index f84824d..d7ede91 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -7,7 +7,7 @@ from torch import Tensor, device as Device, dtype as DType import refiners.fluxion.layers as fl from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.solvers.solver import Solver T = TypeVar("T", bound="fl.Module") @@ -20,7 +20,7 @@ class LatentDiffusionModel(fl.Module, ABC): unet: fl.Module, lda: LatentDiffusionAutoencoder, clip_text_encoder: fl.Module, - scheduler: Scheduler, + solver: Solver, device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: @@ -30,10 +30,10 @@ class LatentDiffusionModel(fl.Module, ABC): self.unet = unet.to(device=self.device, dtype=self.dtype) self.lda = lda.to(device=self.device, dtype=self.dtype) self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype) - self.scheduler = scheduler.to(device=self.device, dtype=self.dtype) + self.solver = solver.to(device=self.device, dtype=self.dtype) def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None: - self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step) + self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step) def init_latents( self, @@ -51,15 +51,15 @@ class LatentDiffusionModel(fl.Module, ABC): if init_image is None: return noise encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height))) - return self.scheduler.add_noise( + return self.solver.add_noise( x=encoded_image, noise=noise, - step=self.scheduler.first_inference_step, + step=self.solver.first_inference_step, ) @property def steps(self) -> list[int]: - return self.scheduler.inference_steps + return self.solver.inference_steps @abstractmethod def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: @@ -82,12 +82,12 @@ class LatentDiffusionModel(fl.Module, ABC): def forward( self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor ) -> Tensor: - timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) + timestep = self.solver.timesteps[step].unsqueeze(dim=0) self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs) latents = 
torch.cat(tensors=(x, x)) # for classifier-free guidance - # scale latents for schedulers that need it - latents = self.scheduler.scale_model_input(latents, step=step) + # scale latents for solvers that need it + latents = self.solver.scale_model_input(latents, step=step) unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2) # classifier-free guidance @@ -101,14 +101,14 @@ class LatentDiffusionModel(fl.Module, ABC): x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs ) - return self.scheduler(x, predicted_noise=predicted_noise, step=step) + return self.solver(x, predicted_noise=predicted_noise, step=step) def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel: return self.__class__( unet=self.unet.structural_copy(), lda=self.lda.structural_copy(), clip_text_encoder=self.clip_text_encoder.structural_copy(), - scheduler=self.scheduler, + solver=self.solver, device=self.device, dtype=self.dtype, ) diff --git a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py index 31deac0..10f0f1f 100644 --- a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py +++ b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py @@ -51,7 +51,7 @@ class MultiDiffusion(Generic[T, D], ABC): match step: case step if step == target.start_step and target.init_latents is not None: noise_view = target.crop(noise) - view = self.ldm.scheduler.add_noise( + view = self.ldm.solver.add_noise( x=target.init_latents, noise=noise_view, step=step, diff --git a/src/refiners/foundationals/latent_diffusion/restart.py b/src/refiners/foundationals/latent_diffusion/restart.py index c436939..282cfcd 100644 --- a/src/refiners/foundationals/latent_diffusion/restart.py +++ b/src/refiners/foundationals/latent_diffusion/restart.py @@ -5,22 +5,22 @@ from typing import Generic, TypeVar import torch from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel -from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM +from refiners.foundationals.latent_diffusion.solvers.solver import Solver T = TypeVar("T", bound=LatentDiffusionModel) def add_noise_interval( - scheduler: Scheduler, + solver: Solver, /, x: torch.Tensor, noise: torch.Tensor, initial_timestep: torch.Tensor, target_timestep: torch.Tensor, ) -> torch.Tensor: - initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep] - target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep] + initial_cumulative_scale_factors = solver.cumulative_scale_factors[initial_timestep] + target_cumulative_scale_factors = solver.cumulative_scale_factors[target_timestep] factor = target_cumulative_scale_factors / initial_cumulative_scale_factors noised_x = factor * x + torch.sqrt(1 - factor**2) * noise @@ -33,7 +33,7 @@ class Restart(Generic[T]): Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes" (https://arxiv.org/pdf/2306.14878.pdf) - Works only with the DDIM scheduler for now. + Works only with the DDIM solver for now. 
""" ldm: T @@ -43,7 +43,7 @@ class Restart(Generic[T]): end_time: float = 2 def __post_init__(self) -> None: - assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler" + assert isinstance(self.ldm.solver, DDIM), "Restart sampling only works with DDIM solver" def __call__( self, @@ -53,15 +53,15 @@ class Restart(Generic[T]): condition_scale: float = 7.5, **kwargs: torch.Tensor, ) -> torch.Tensor: - original_scheduler = self.ldm.scheduler - new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype) - new_scheduler.timesteps = self.timesteps - self.ldm.scheduler = new_scheduler + original_solver = self.ldm.solver + new_solver = DDIM(self.ldm.solver.num_inference_steps, device=self.device, dtype=self.dtype) + new_solver.timesteps = self.timesteps + self.ldm.solver = new_solver for _ in range(self.num_iterations): noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype) x = add_noise_interval( - new_scheduler, + new_solver, x=x, noise=noise, initial_timestep=self.timesteps[-1], @@ -73,18 +73,18 @@ class Restart(Generic[T]): x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs ) - self.ldm.scheduler = original_scheduler + self.ldm.solver = original_solver return x @cached_property def start_step(self) -> int: - sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors - return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time))) + sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors + return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.solver.timesteps] - self.start_time))) @cached_property def end_timestep(self) -> int: - sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors + sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time))) @cached_property @@ -92,7 +92,7 @@ class Restart(Generic[T]): return ( torch.round( torch.linspace( - start=int(self.ldm.scheduler.timesteps[self.start_step]), + start=int(self.ldm.solver.timesteps[self.start_step]), end=self.end_timestep, steps=self.num_steps, ) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py b/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py deleted file mode 100644 index 3eb00cf..0000000 --- a/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM -from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM -from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver -from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler - -__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"] diff --git a/src/refiners/foundationals/latent_diffusion/self_attention_guidance.py b/src/refiners/foundationals/latent_diffusion/self_attention_guidance.py index 1916a58..cddbfb0 100644 --- a/src/refiners/foundationals/latent_diffusion/self_attention_guidance.py +++ b/src/refiners/foundationals/latent_diffusion/self_attention_guidance.py @@ -9,7 +9,7 @@ import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.context import Contexts 
from refiners.fluxion.utils import gaussian_blur, interpolate -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.solvers.solver import Solver if TYPE_CHECKING: from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet @@ -89,13 +89,13 @@ class SAGAdapter(Generic[T], fl.Chain, Adapter[T]): return interpolate(attn_mask, Size((h, w))) def compute_degraded_latents( - self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True + self, solver: Solver, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True ) -> Tensor: sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance) - original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step) + original_latents = solver.remove_noise(x=latents, noise=noise, step=step) degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma) degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask) - return scheduler.add_noise(degraded_latents, noise=noise, step=step) + return solver.add_noise(degraded_latents, noise=noise, step=step) def init_context(self) -> Contexts: return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}} diff --git a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py new file mode 100644 index 0000000..bd41fdb --- /dev/null +++ b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py @@ -0,0 +1,7 @@ +from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM +from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM +from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver +from refiners.foundationals.latent_diffusion.solvers.euler import Euler +from refiners.foundationals.latent_diffusion.solvers.solver import Solver + +__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler"] diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py similarity index 88% rename from src/refiners/foundationals/latent_diffusion/schedulers/ddim.py rename to src/refiners/foundationals/latent_diffusion/solvers/ddim.py index 16d6ceb..8e02d45 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py @@ -1,9 +1,9 @@ from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor -from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver -class DDIM(Scheduler): +class DDIM(Solver): def __init__( self, num_inference_steps: int, @@ -25,15 +25,14 @@ class DDIM(Scheduler): device=device, dtype=dtype, ) - self.timesteps = self._generate_timesteps() def _generate_timesteps(self) -> Tensor: """ Generates decreasing timesteps with 'leading' spacing and offset of 1 - similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5 + similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5 """ step_ratio = self.num_train_timesteps // self.num_inference_steps - timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1 + 
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1 return timesteps.flip(0) def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py similarity index 91% rename from src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py rename to src/refiners/foundationals/latent_diffusion/solvers/ddpm.py index 109d65b..31cb52b 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py @@ -1,9 +1,9 @@ from torch import Generator, Tensor, arange, device as Device, dtype as DType -from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver -class DDPM(Scheduler): +class DDPM(Solver): """ Denoising Diffusion Probabilistic Model diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py similarity index 97% rename from src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py rename to src/refiners/foundationals/latent_diffusion/solvers/dpm.py index 41167a7..ca28913 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -3,10 +3,10 @@ from collections import deque import numpy as np from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor -from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver -class DPMSolver(Scheduler): +class DPMSolver(Solver): """ Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095 @@ -48,7 +48,6 @@ class DPMSolver(Scheduler): # ...and we want the same result as the original codebase. return tensor( np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:], - device=self.device, ).flip(0) def rebuild( diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py similarity index 90% rename from src/refiners/foundationals/latent_diffusion/schedulers/euler.py rename to src/refiners/foundationals/latent_diffusion/solvers/euler.py index 690a8e5..17711cc 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py @@ -2,10 +2,10 @@ import numpy as np import torch from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor -from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver -class EulerScheduler(Scheduler): +class Euler(Solver): def __init__( self, num_inference_steps: int, @@ -40,9 +40,7 @@ class EulerScheduler(Scheduler): # numpy.linspace(0,999,31)[15] is 499.49999999999994 # torch.linspace(0,999,31)[15] is 499.5 # ...and we want the same result as the original codebase. 
- timesteps = torch.tensor( - np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device - ).flip(0) + timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0) return timesteps def _generate_sigmas(self) -> Tensor: diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py similarity index 79% rename from src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py rename to src/refiners/foundationals/latent_diffusion/solvers/solver.py index 25fbe08..6579fa1 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py @@ -4,7 +4,9 @@ from typing import TypeVar from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt -T = TypeVar("T", bound="Scheduler") +from refiners.fluxion import layers as fl + +T = TypeVar("T", bound="Solver") class NoiseSchedule(str, Enum): @@ -13,11 +15,11 @@ class NoiseSchedule(str, Enum): KARRAS = "karras" -class Scheduler(ABC): +class Solver(fl.Module, ABC): """ - A base class for creating a diffusion model scheduler. + A base class for creating a diffusion model solver. - The Scheduler creates a sequence of noise and scaling factors used in the diffusion process, + Solver creates a sequence of noise and scaling factors used in the diffusion process, which gradually transforms the original data distribution into a Gaussian one. This process is described using several parameters such as initial and final diffusion rates, @@ -36,9 +38,8 @@ class Scheduler(ABC): first_inference_step: int = 0, device: Device | str = "cpu", dtype: DType = float32, - ): - self.device: Device = Device(device) - self.dtype: DType = dtype + ) -> None: + super().__init__() self.num_inference_steps = num_inference_steps self.num_train_timesteps = num_train_timesteps self.initial_diffusion_rate = initial_diffusion_rate @@ -50,6 +51,7 @@ class Scheduler(ABC): self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0)) self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std) self.timesteps = self._generate_timesteps() + self.to(device=device, dtype=dtype) @abstractmethod def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: @@ -69,57 +71,6 @@ class Scheduler(ABC): """ ... 
- @property - def all_steps(self) -> list[int]: - return list(range(self.num_inference_steps)) - - @property - def inference_steps(self) -> list[int]: - return self.all_steps[self.first_inference_step :] - - def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T: - num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps - first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step - return self.__class__( - num_inference_steps=num_inference_steps, - num_train_timesteps=self.num_train_timesteps, - initial_diffusion_rate=self.initial_diffusion_rate, - final_diffusion_rate=self.final_diffusion_rate, - noise_schedule=self.noise_schedule, - first_inference_step=first_inference_step, - device=self.device, - dtype=self.dtype, - ) - - def scale_model_input(self, x: Tensor, step: int) -> Tensor: - """ - For compatibility with schedulers that need to scale the input according to the current timestep. - """ - return x - - def sample_power_distribution(self, power: float = 2, /) -> Tensor: - return ( - linspace( - start=self.initial_diffusion_rate ** (1 / power), - end=self.final_diffusion_rate ** (1 / power), - steps=self.num_train_timesteps, - device=self.device, - dtype=self.dtype, - ) - ** power - ) - - def sample_noise_schedule(self) -> Tensor: - match self.noise_schedule: - case "uniform": - return 1 - self.sample_power_distribution(1) - case "quadratic": - return 1 - self.sample_power_distribution(2) - case "karras": - return 1 - self.sample_power_distribution(7) - case _: - raise ValueError(f"Unknown noise schedule: {self.noise_schedule}") - def add_noise( self, x: Tensor, @@ -141,14 +92,77 @@ class Scheduler(ABC): denoised_x = (x - noise_stds * noise) / cumulative_scale_factors return denoised_x - def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore - if device is not None: - self.device = Device(device) - self.timesteps = self.timesteps.to(device) - if dtype is not None: - self.dtype = dtype - self.scale_factors = self.scale_factors.to(device, dtype=dtype) - self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype) - self.noise_std = self.noise_std.to(device, dtype=dtype) - self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype) + @property + def all_steps(self) -> list[int]: + return list(range(self.num_inference_steps)) + + @property + def inference_steps(self) -> list[int]: + return self.all_steps[self.first_inference_step :] + + @property + def device(self) -> Device: + return self.scale_factors.device + + @property + def dtype(self) -> DType: + return self.scale_factors.dtype + + @device.setter + def device(self, device: Device | str | None = None) -> None: + self.to(device=device) + + @dtype.setter + def dtype(self, dtype: DType | None = None) -> None: + self.to(dtype=dtype) + + def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T: + num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps + first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step + return self.__class__( + num_inference_steps=num_inference_steps, + num_train_timesteps=self.num_train_timesteps, + initial_diffusion_rate=self.initial_diffusion_rate, + final_diffusion_rate=self.final_diffusion_rate, + noise_schedule=self.noise_schedule, + 
first_inference_step=first_inference_step, + device=self.device, + dtype=self.dtype, + ) + + def scale_model_input(self, x: Tensor, step: int) -> Tensor: + """ + For compatibility with solvers that need to scale the input according to the current timestep. + """ + return x + + def sample_power_distribution(self, power: float = 2, /) -> Tensor: + return ( + linspace( + start=self.initial_diffusion_rate ** (1 / power), + end=self.final_diffusion_rate ** (1 / power), + steps=self.num_train_timesteps, + ) + ** power + ) + + def sample_noise_schedule(self) -> Tensor: + match self.noise_schedule: + case "uniform": + return 1 - self.sample_power_distribution(1) + case "quadratic": + return 1 - self.sample_power_distribution(2) + case "karras": + return 1 - self.sample_power_distribution(7) + case _: + raise ValueError(f"Unknown noise schedule: {self.noise_schedule}") + + def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver": + super().to(device=device, dtype=dtype) + for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]: + match name: + case "timesteps": + setattr(self, name, attr.to(device=device)) + case _: + setattr(self, name, attr.to(device=device, dtype=dtype)) return self diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index 716fea3..6c9482c 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -7,8 +7,8 @@ from refiners.fluxion.utils import image_to_tensor, interpolate from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel -from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver +from refiners.foundationals.latent_diffusion.solvers.solver import Solver from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet @@ -26,20 +26,20 @@ class StableDiffusion_1(LatentDiffusionModel): unet: SD1UNet | None = None, lda: SD1Autoencoder | None = None, clip_text_encoder: CLIPTextEncoderL | None = None, - scheduler: Scheduler | None = None, + solver: Solver | None = None, device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: unet = unet or SD1UNet(in_channels=4) lda = lda or SD1Autoencoder() clip_text_encoder = clip_text_encoder or CLIPTextEncoderL() - scheduler = scheduler or DPMSolver(num_inference_steps=30) + solver = solver or DPMSolver(num_inference_steps=30) super().__init__( unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, - scheduler=scheduler, + solver=solver, device=device, dtype=dtype, ) @@ -82,14 +82,14 @@ class StableDiffusion_1(LatentDiffusionModel): assert sag is not None degraded_latents = sag.compute_degraded_latents( - scheduler=self.scheduler, + solver=self.solver, latents=x, noise=noise, step=step, classifier_free_guidance=True, ) - timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) + timestep = self.solver.timesteps[step].unsqueeze(dim=0) 
negative_embedding, _ = clip_text_embedding.chunk(2) self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs) if "ip_adapter" in self.unet.provider.contexts: @@ -111,14 +111,14 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1): unet: SD1UNet | None = None, lda: SD1Autoencoder | None = None, clip_text_encoder: CLIPTextEncoderL | None = None, - scheduler: Scheduler | None = None, + solver: Solver | None = None, device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: self.mask_latents: Tensor | None = None self.target_image_latents: Tensor | None = None super().__init__( - unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype + unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, solver=solver, device=device, dtype=dtype ) def forward( @@ -162,7 +162,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1): assert self.target_image_latents is not None degraded_latents = sag.compute_degraded_latents( - scheduler=self.scheduler, + solver=self.solver, latents=x, noise=noise, step=step, @@ -173,7 +173,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1): dim=1, ) - timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) + timestep = self.solver.timesteps[step].unsqueeze(dim=0) negative_embedding, _ = clip_text_embedding.chunk(2) self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py index 1971b32..95d0b80 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py @@ -3,8 +3,8 @@ from torch import Tensor, device as Device, dtype as DType from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel -from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM +from refiners.foundationals.latent_diffusion.solvers.solver import Solver from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet @@ -23,20 +23,20 @@ class StableDiffusion_XL(LatentDiffusionModel): unet: SDXLUNet | None = None, lda: SDXLAutoencoder | None = None, clip_text_encoder: DoubleTextEncoder | None = None, - scheduler: Scheduler | None = None, + solver: Solver | None = None, device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: unet = unet or SDXLUNet(in_channels=4) lda = lda or SDXLAutoencoder() clip_text_encoder = clip_text_encoder or DoubleTextEncoder() - scheduler = scheduler or DDIM(num_inference_steps=30) + solver = solver or DDIM(num_inference_steps=30) super().__init__( unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, - scheduler=scheduler, + solver=solver, device=device, dtype=dtype, ) @@ -131,7 +131,7 @@ class StableDiffusion_XL(LatentDiffusionModel): assert sag is not None degraded_latents = sag.compute_degraded_latents( - scheduler=self.scheduler, + solver=self.solver, 
latents=x, noise=noise, step=step, @@ -140,7 +140,7 @@ class StableDiffusion_XL(LatentDiffusionModel): negative_text_embedding, _ = clip_text_embedding.chunk(2) negative_pooled_embedding, _ = pooled_text_embedding.chunk(2) - timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) + timestep = self.solver.timesteps[step].unsqueeze(dim=0) time_ids, _ = time_ids.chunk(2) self.set_unet_context( diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index f4f8ccf..732ed60 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -19,7 +19,8 @@ from refiners.foundationals.latent_diffusion import ( SD1UNet, StableDiffusion_1, ) -from refiners.foundationals.latent_diffusion.schedulers import DDPM +from refiners.foundationals.latent_diffusion.solvers import DDPM +from refiners.foundationals.latent_diffusion.solvers.solver import Solver from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder from refiners.training_utils.callback import Callback from refiners.training_utils.config import BaseConfig @@ -150,7 +151,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): return TextEmbeddingLatentsDataset(trainer=self) @cached_property - def ddpm_scheduler(self) -> DDPM: + def ddpm_solver(self) -> Solver: return DDPM( num_inference_steps=1000, device=self.device, @@ -159,7 +160,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): def sample_timestep(self) -> Tensor: random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step) self.current_step = random_step - return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0) + return self.ddpm_solver.timesteps[random_step].unsqueeze(dim=0) def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor: return sample_noise( @@ -170,7 +171,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): clip_text_embedding, latents = batch.text_embeddings, batch.latents timestep = self.sample_timestep() noise = self.sample_noise(size=latents.shape, dtype=latents.dtype) - noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step) + noisy_latents = self.ddpm_solver.add_noise(x=latents, noise=noise, step=self.current_step) self.unet.set_timestep(timestep=timestep) self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) prediction = self.unet(noisy_latents) @@ -182,7 +183,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): unet=self.unet, lda=self.lda, clip_text_encoder=self.text_encoder, - scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps), + solver=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps), device=self.device, ) prompts = self.config.test_diffusion.prompts diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index b282718..d47910f 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -24,8 +24,8 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.restart import Restart -from 
refiners.foundationals.latent_diffusion.schedulers import DDIM, EulerScheduler -from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule +from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from tests.utils import ensure_similar_images @@ -491,8 +491,8 @@ def sd15_ddim( warn("not running on CPU, skipping") pytest.skip() - ddim_scheduler = DDIM(num_inference_steps=20) - sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device) + ddim_solver = DDIM(num_inference_steps=20) + sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) @@ -509,8 +509,8 @@ def sd15_ddim_karras( warn("not running on CPU, skipping") pytest.skip() - ddim_scheduler = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS) - sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device) + ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS) + sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) @@ -527,8 +527,8 @@ def sd15_euler( warn("not running on CPU, skipping") pytest.skip() - euler_scheduler = EulerScheduler(num_inference_steps=30) - sd15 = StableDiffusion_1(scheduler=euler_scheduler, device=test_device) + euler_solver = Euler(num_inference_steps=30) + sd15 = StableDiffusion_1(solver=euler_solver, device=test_device) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.lda.load_from_safetensors(lda_weights) @@ -545,8 +545,8 @@ def sd15_ddim_lda_ft_mse( warn("not running on CPU, skipping") pytest.skip() - ddim_scheduler = DDIM(num_inference_steps=20) - sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device) + ddim_solver = DDIM(num_inference_steps=20) + sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights)) sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights)) @@ -599,8 +599,8 @@ def sdxl_ddim( warn(message="not running on CPU, skipping") pytest.skip() - scheduler = DDIM(num_inference_steps=30) - sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device) + solver = DDIM(num_inference_steps=30) + sdxl = StableDiffusion_XL(solver=solver, device=test_device) sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights) sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights) @@ -617,8 +617,8 @@ def sdxl_ddim_lda_fp16_fix( warn(message="not running on CPU, skipping") pytest.skip() - scheduler = DDIM(num_inference_steps=30) - sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device) + solver = DDIM(num_inference_steps=30) + sdxl = StableDiffusion_XL(solver=solver, device=test_device) sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights) sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights) @@ -659,8 +659,8 @@ def test_diffusion_std_random_init_euler( sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device ): sd15 = 
sd15_euler - euler_scheduler = sd15_euler.scheduler - assert isinstance(euler_scheduler, EulerScheduler) + euler_solver = sd15_euler.solver + assert isinstance(euler_solver, Euler) prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" @@ -670,7 +670,7 @@ def test_diffusion_std_random_init_euler( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - x = x * euler_scheduler.init_noise_sigma + x = x * euler_solver.init_noise_sigma for step in sd15.steps: x = sd15( @@ -1202,7 +1202,7 @@ def test_diffusion_refonly( for step in sd15.steps: noise = torch.randn(2, 4, 64, 64, device=test_device) - noised_guide = sd15.scheduler.add_noise(guide, noise, step) + noised_guide = sd15.solver.add_noise(guide, noise, step) refonly_adapter.set_controlnet_condition(noised_guide) x = sd15( x, @@ -1244,7 +1244,7 @@ def test_diffusion_inpainting_refonly( for step in sd15.steps: noise = torch.randn_like(guide) - noised_guide = sd15.scheduler.add_noise(guide, noise, step) + noised_guide = sd15.solver.add_noise(guide, noise, step) # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support # inpaint variation models") noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1) diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_solvers.py similarity index 96% rename from tests/foundationals/latent_diffusion/test_schedulers.py rename to tests/foundationals/latent_diffusion/test_solvers.py index 8423356..f7c8c4d 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -5,7 +5,7 @@ import pytest from torch import Tensor, allclose, device as Device, equal, isclose, randn from refiners.fluxion import manual_seed -from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler +from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler def test_ddpm_diffusers(): @@ -83,7 +83,7 @@ def test_euler_diffusers(): use_karras_sigmas=False, ) diffusers_scheduler.set_timesteps(30) - refiners_scheduler = EulerScheduler(num_inference_steps=30) + refiners_scheduler = Euler(num_inference_steps=30) sample = randn(1, 4, 32, 32) predicted_noise = randn(1, 4, 32, 32)
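
Usage after the rename: models are constructed with a `solver` argument, and `LatentDiffusionModel.steps` proxies `solver.inference_steps`. Below is a minimal denoising-loop sketch against the new API; the random tensor is a stand-in for real concatenated negative + positive CLIP embeddings (the forward pass chunks it in two for classifier-free guidance), and checkpoint loading is omitted:

    import torch

    from refiners.foundationals.latent_diffusion.solvers import DDIM
    from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1

    # The constructor keyword is now `solver` (formerly `scheduler`); weights
    # stay random here -- the load_from_safetensors calls are omitted for brevity.
    sd15 = StableDiffusion_1(solver=DDIM(num_inference_steps=30), device="cpu")

    # Stand-in for (negative, positive) CLIP embeddings, batch of 2.
    clip_text_embedding = torch.randn(2, 77, 768)

    x = torch.randn(1, 4, 64, 64)
    for step in sd15.steps:  # proxies sd15.solver.inference_steps
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)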
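
Note that `set_inference_steps` does not mutate the solver in place: it swaps in a fresh instance produced by `Solver.rebuild`, which re-invokes the subclass constructor with the overridden step counts while keeping every other hyperparameter. Passing `None` for the first positional argument means "keep the current value":

    from refiners.foundationals.latent_diffusion.solvers import DDIM

    solver = DDIM(num_inference_steps=30)
    resumed = solver.rebuild(num_inference_steps=None, first_inference_step=10)

    assert resumed.num_inference_steps == 30                # inherited unchanged
    assert resumed.inference_steps == list(range(10, 30))   # all_steps[10:]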
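
Since `Solver` is now an `fl.Module`, `device` and `dtype` are read back from the registered tensors (via `scale_factors`) instead of being stored as separate attributes, and the overridden `to` casts every tensor attribute except `timesteps`, which only changes device so it remains a valid integer index. A small sketch of that behavior:

    import torch

    from refiners.foundationals.latent_diffusion.solvers import DDIM

    solver = DDIM(num_inference_steps=30).to(dtype=torch.float16)

    assert solver.dtype == torch.float16          # derived from scale_factors
    assert solver.timesteps.dtype == torch.int64  # untouched by to(dtype=...)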