make Scheduler a fl.Module + Change name Scheduler -> Solver

limiteinductive 2024-01-31 14:07:34 +00:00 committed by Benjamin Trom
parent 07cb2ff21c
commit 73f6ccfc98
19 changed files with 184 additions and 173 deletions
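In user code the change is mechanical: the `scheduler` argument and attribute become `solver`, and the classes move to the `solvers` package. A minimal usage sketch of the renamed API, based on the hunks below (weight loading is omitted, so the outputs are meaningless; the construct mirrors the updated tests in this commit):

```python
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.solvers import DDIM

# Before this commit: StableDiffusion_1(scheduler=DDIM(num_inference_steps=30), device="cpu")
sd15 = StableDiffusion_1(solver=DDIM(num_inference_steps=30), device="cpu")

# The solver is exposed as `sd15.solver` (previously `sd15.scheduler`).
for step in sd15.steps:
    ...  # denoising loop, unchanged apart from the attribute name
```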

View file

@@ -19,7 +19,7 @@
 ## Latest News 🔥
-- Added [Euler's method](https://arxiv.org/abs/2206.00364) to schedulers (contributed by [@israfelsr](https://github.com/israfelsr))
+- Added [Euler's method](https://arxiv.org/abs/2206.00364) to solvers (contributed by [@israfelsr](https://github.com/israfelsr))
 - Added [DINOv2](https://github.com/facebookresearch/dinov2) for high-performance visual features (contributed by [@Laurent2916](https://github.com/Laurent2916))
 - Added [FreeU](https://github.com/ChenyangSi/FreeU) for improved quality at no cost (contributed by [@isamu-isozaki](https://github.com/isamu-isozaki))
 - Added [Restart Sampling](https://github.com/Newbeeer/diffusion_restart_sampling) for improved image generation ([example](https://github.com/Newbeeer/diffusion_restart_sampling/issues/4))

View file

@@ -37,8 +37,8 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
     clip_text_embedding = torch.rand(1, 77, 768)
     unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
-    scheduler = DPMSolver(num_inference_steps=10)
-    timestep = scheduler.timesteps[0].unsqueeze(dim=0)
+    solver = DPMSolver(num_inference_steps=10)
+    timestep = solver.timesteps[0].unsqueeze(dim=0)
     unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
     x = torch.randn(1, 4, 64, 64)

View file

@@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
     LatentDiffusionAutoencoder,
 )
 from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
-from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler
+from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     SD1ControlnetAdapter,
     SD1IPAdapter,
@@ -33,7 +33,7 @@ __all__ = [
     "SDXLIPAdapter",
     "SDXLT2IAdapter",
     "DPMSolver",
-    "Scheduler",
+    "Solver",
     "CLIPTextEncoderL",
     "LatentDiffusionAutoencoder",
     "SDFreeUAdapter",

View file

@@ -7,7 +7,7 @@ from torch import Tensor, device as Device, dtype as DType

 import refiners.fluxion.layers as fl
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver

 T = TypeVar("T", bound="fl.Module")
@@ -20,7 +20,7 @@ class LatentDiffusionModel(fl.Module, ABC):
         unet: fl.Module,
         lda: LatentDiffusionAutoencoder,
         clip_text_encoder: fl.Module,
-        scheduler: Scheduler,
+        solver: Solver,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
@@ -30,10 +30,10 @@ class LatentDiffusionModel(fl.Module, ABC):
         self.unet = unet.to(device=self.device, dtype=self.dtype)
         self.lda = lda.to(device=self.device, dtype=self.dtype)
         self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
-        self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
+        self.solver = solver.to(device=self.device, dtype=self.dtype)

     def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
-        self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
+        self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)

     def init_latents(
         self,
@@ -51,15 +51,15 @@ class LatentDiffusionModel(fl.Module, ABC):
         if init_image is None:
             return noise
         encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
-        return self.scheduler.add_noise(
+        return self.solver.add_noise(
             x=encoded_image,
             noise=noise,
-            step=self.scheduler.first_inference_step,
+            step=self.solver.first_inference_step,
         )

     @property
     def steps(self) -> list[int]:
-        return self.scheduler.inference_steps
+        return self.solver.inference_steps

     @abstractmethod
     def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
@@ -82,12 +82,12 @@ class LatentDiffusionModel(fl.Module, ABC):
     def forward(
         self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
     ) -> Tensor:
-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        timestep = self.solver.timesteps[step].unsqueeze(dim=0)
         self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

         latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
-        # scale latents for schedulers that need it
-        latents = self.scheduler.scale_model_input(latents, step=step)
+        # scale latents for solvers that need it
+        latents = self.solver.scale_model_input(latents, step=step)
         unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

         # classifier-free guidance
@@ -101,14 +101,14 @@ class LatentDiffusionModel(fl.Module, ABC):
                 x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
             )

-        return self.scheduler(x, predicted_noise=predicted_noise, step=step)
+        return self.solver(x, predicted_noise=predicted_noise, step=step)

     def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
         return self.__class__(
             unet=self.unet.structural_copy(),
             lda=self.lda.structural_copy(),
             clip_text_encoder=self.clip_text_encoder.structural_copy(),
-            scheduler=self.scheduler,
+            solver=self.solver,
             device=self.device,
             dtype=self.dtype,
         )

View file

@@ -51,7 +51,7 @@ class MultiDiffusion(Generic[T, D], ABC):
             match step:
                 case step if step == target.start_step and target.init_latents is not None:
                     noise_view = target.crop(noise)
-                    view = self.ldm.scheduler.add_noise(
+                    view = self.ldm.solver.add_noise(
                         x=target.init_latents,
                         noise=noise_view,
                         step=step,

View file

@@ -5,22 +5,22 @@ from typing import Generic, TypeVar
 import torch

 from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
-from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver

 T = TypeVar("T", bound=LatentDiffusionModel)


 def add_noise_interval(
-    scheduler: Scheduler,
+    solver: Solver,
     /,
     x: torch.Tensor,
     noise: torch.Tensor,
     initial_timestep: torch.Tensor,
     target_timestep: torch.Tensor,
 ) -> torch.Tensor:
-    initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
-    target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
+    initial_cumulative_scale_factors = solver.cumulative_scale_factors[initial_timestep]
+    target_cumulative_scale_factors = solver.cumulative_scale_factors[target_timestep]
     factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
     noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
@@ -33,7 +33,7 @@ class Restart(Generic[T]):
     Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
     (https://arxiv.org/pdf/2306.14878.pdf)

-    Works only with the DDIM scheduler for now.
+    Works only with the DDIM solver for now.
     """

     ldm: T
@@ -43,7 +43,7 @@ class Restart(Generic[T]):
     end_time: float = 2

     def __post_init__(self) -> None:
-        assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
+        assert isinstance(self.ldm.solver, DDIM), "Restart sampling only works with DDIM solver"

     def __call__(
         self,
@@ -53,15 +53,15 @@ class Restart(Generic[T]):
         condition_scale: float = 7.5,
         **kwargs: torch.Tensor,
     ) -> torch.Tensor:
-        original_scheduler = self.ldm.scheduler
-        new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
-        new_scheduler.timesteps = self.timesteps
-        self.ldm.scheduler = new_scheduler
+        original_solver = self.ldm.solver
+        new_solver = DDIM(self.ldm.solver.num_inference_steps, device=self.device, dtype=self.dtype)
+        new_solver.timesteps = self.timesteps
+        self.ldm.solver = new_solver

         for _ in range(self.num_iterations):
             noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
             x = add_noise_interval(
-                new_scheduler,
+                new_solver,
                 x=x,
                 noise=noise,
                 initial_timestep=self.timesteps[-1],
@@ -73,18 +73,18 @@ class Restart(Generic[T]):
                 x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
             )

-        self.ldm.scheduler = original_scheduler
+        self.ldm.solver = original_solver

         return x

     @cached_property
     def start_step(self) -> int:
-        sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
-        return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
+        sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
+        return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.solver.timesteps] - self.start_time)))

     @cached_property
     def end_timestep(self) -> int:
-        sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
+        sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
         return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))

     @cached_property
@@ -92,7 +92,7 @@ class Restart(Generic[T]):
         return (
             torch.round(
                 torch.linspace(
-                    start=int(self.ldm.scheduler.timesteps[self.start_step]),
+                    start=int(self.ldm.solver.timesteps[self.start_step]),
                     end=self.end_timestep,
                     steps=self.num_steps,
                 )

View file

@@ -1,7 +0,0 @@
-from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
-from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
-from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
-from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
-
-__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]

View file

@@ -9,7 +9,7 @@ import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.utils import gaussian_blur, interpolate
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver

 if TYPE_CHECKING:
     from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@@ -89,13 +89,13 @@ class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
         return interpolate(attn_mask, Size((h, w)))

     def compute_degraded_latents(
-        self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
+        self, solver: Solver, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
     ) -> Tensor:
         sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
-        original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
+        original_latents = solver.remove_noise(x=latents, noise=noise, step=step)
         degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
         degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
-        return scheduler.add_noise(degraded_latents, noise=noise, step=step)
+        return solver.add_noise(degraded_latents, noise=noise, step=step)

     def init_context(self) -> Contexts:
         return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}

View file

@@ -0,0 +1,7 @@
+from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
+from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
+from refiners.foundationals.latent_diffusion.solvers.euler import Euler
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
+
+__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler"]
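Put differently, downstream imports move from the `schedulers` package to `solvers`, with `Scheduler` renamed to `Solver`, `EulerScheduler` to `Euler`, and the `dpm_solver` module to `dpm`. A minimal before/after import sketch, following the rewrites throughout this diff:

```python
# Old import paths (removed by this commit):
# from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler, Scheduler

# New import paths:
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, Solver
```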

View file

@@ -1,9 +1,9 @@
 from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor

-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver


-class DDIM(Scheduler):
+class DDIM(Solver):
     def __init__(
         self,
         num_inference_steps: int,
@@ -25,15 +25,14 @@ class DDIM(Scheduler):
             device=device,
             dtype=dtype,
         )
-        self.timesteps = self._generate_timesteps()

     def _generate_timesteps(self) -> Tensor:
         """
         Generates decreasing timesteps with 'leading' spacing and offset of 1
-        similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
+        similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5
         """
         step_ratio = self.num_train_timesteps // self.num_inference_steps
-        timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
+        timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
         return timesteps.flip(0)

     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:

View file

@@ -1,9 +1,9 @@
 from torch import Generator, Tensor, arange, device as Device, dtype as DType

-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver


-class DDPM(Scheduler):
+class DDPM(Solver):
     """
     Denoising Diffusion Probabilistic Model

View file

@@ -3,10 +3,10 @@ from collections import deque
 import numpy as np
 from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor

-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver


-class DPMSolver(Scheduler):
+class DPMSolver(Solver):
     """
     Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
@@ -48,7 +48,6 @@ class DPMSolver(Scheduler):
         # ...and we want the same result as the original codebase.
         return tensor(
             np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
-            device=self.device,
         ).flip(0)

     def rebuild(

View file

@@ -2,10 +2,10 @@ import numpy as np
 import torch
 from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor

-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver


-class EulerScheduler(Scheduler):
+class Euler(Solver):
     def __init__(
         self,
         num_inference_steps: int,
@@ -40,9 +40,7 @@ class EulerScheduler(Scheduler):
         # numpy.linspace(0,999,31)[15] is 499.49999999999994
         # torch.linspace(0,999,31)[15] is 499.5
         # ...and we want the same result as the original codebase.
-        timesteps = torch.tensor(
-            np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
-        ).flip(0)
+        timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
         return timesteps

     def _generate_sigmas(self) -> Tensor:

View file

@@ -4,7 +4,9 @@ from typing import TypeVar

 from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt

-T = TypeVar("T", bound="Scheduler")
+from refiners.fluxion import layers as fl
+
+T = TypeVar("T", bound="Solver")


 class NoiseSchedule(str, Enum):
@@ -13,11 +15,11 @@ class NoiseSchedule(str, Enum):
     KARRAS = "karras"


-class Scheduler(ABC):
+class Solver(fl.Module, ABC):
     """
-    A base class for creating a diffusion model scheduler.
+    A base class for creating a diffusion model solver.

-    The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
+    Solver creates a sequence of noise and scaling factors used in the diffusion process,
     which gradually transforms the original data distribution into a Gaussian one.

     This process is described using several parameters such as initial and final diffusion rates,
@@ -36,9 +38,8 @@ class Scheduler(ABC):
         first_inference_step: int = 0,
         device: Device | str = "cpu",
         dtype: DType = float32,
-    ):
-        self.device: Device = Device(device)
-        self.dtype: DType = dtype
+    ) -> None:
+        super().__init__()
         self.num_inference_steps = num_inference_steps
         self.num_train_timesteps = num_train_timesteps
         self.initial_diffusion_rate = initial_diffusion_rate
@@ -50,6 +51,7 @@ class Scheduler(ABC):
         self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
         self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
         self.timesteps = self._generate_timesteps()
+        self.to(device=device, dtype=dtype)

     @abstractmethod
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
@@ -69,57 +71,6 @@ class Scheduler(ABC):
         """
         ...

-    @property
-    def all_steps(self) -> list[int]:
-        return list(range(self.num_inference_steps))
-
-    @property
-    def inference_steps(self) -> list[int]:
-        return self.all_steps[self.first_inference_step :]
-
-    def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
-        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
-        first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
-        return self.__class__(
-            num_inference_steps=num_inference_steps,
-            num_train_timesteps=self.num_train_timesteps,
-            initial_diffusion_rate=self.initial_diffusion_rate,
-            final_diffusion_rate=self.final_diffusion_rate,
-            noise_schedule=self.noise_schedule,
-            first_inference_step=first_inference_step,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
-        """
-        For compatibility with schedulers that need to scale the input according to the current timestep.
-        """
-        return x
-
-    def sample_power_distribution(self, power: float = 2, /) -> Tensor:
-        return (
-            linspace(
-                start=self.initial_diffusion_rate ** (1 / power),
-                end=self.final_diffusion_rate ** (1 / power),
-                steps=self.num_train_timesteps,
-                device=self.device,
-                dtype=self.dtype,
-            )
-            ** power
-        )
-
-    def sample_noise_schedule(self) -> Tensor:
-        match self.noise_schedule:
-            case "uniform":
-                return 1 - self.sample_power_distribution(1)
-            case "quadratic":
-                return 1 - self.sample_power_distribution(2)
-            case "karras":
-                return 1 - self.sample_power_distribution(7)
-            case _:
-                raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
-
     def add_noise(
         self,
         x: Tensor,
@@ -141,14 +92,77 @@ class Scheduler(ABC):
         denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
         return denoised_x

-    def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
-        if device is not None:
-            self.device = Device(device)
-            self.timesteps = self.timesteps.to(device)
-        if dtype is not None:
-            self.dtype = dtype
-        self.scale_factors = self.scale_factors.to(device, dtype=dtype)
-        self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
-        self.noise_std = self.noise_std.to(device, dtype=dtype)
-        self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
+    @property
+    def all_steps(self) -> list[int]:
+        return list(range(self.num_inference_steps))
+
+    @property
+    def inference_steps(self) -> list[int]:
+        return self.all_steps[self.first_inference_step :]
+
+    @property
+    def device(self) -> Device:
+        return self.scale_factors.device
+
+    @property
+    def dtype(self) -> DType:
+        return self.scale_factors.dtype
+
+    @device.setter
+    def device(self, device: Device | str | None = None) -> None:
+        self.to(device=device)
+
+    @dtype.setter
+    def dtype(self, dtype: DType | None = None) -> None:
+        self.to(dtype=dtype)
+
+    def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
+        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
+        first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
+        return self.__class__(
+            num_inference_steps=num_inference_steps,
+            num_train_timesteps=self.num_train_timesteps,
+            initial_diffusion_rate=self.initial_diffusion_rate,
+            final_diffusion_rate=self.final_diffusion_rate,
+            noise_schedule=self.noise_schedule,
+            first_inference_step=first_inference_step,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
+        """
+        For compatibility with solvers that need to scale the input according to the current timestep.
+        """
+        return x
+
+    def sample_power_distribution(self, power: float = 2, /) -> Tensor:
+        return (
+            linspace(
+                start=self.initial_diffusion_rate ** (1 / power),
+                end=self.final_diffusion_rate ** (1 / power),
+                steps=self.num_train_timesteps,
+            )
+            ** power
+        )
+
+    def sample_noise_schedule(self) -> Tensor:
+        match self.noise_schedule:
+            case "uniform":
+                return 1 - self.sample_power_distribution(1)
+            case "quadratic":
+                return 1 - self.sample_power_distribution(2)
+            case "karras":
+                return 1 - self.sample_power_distribution(7)
+            case _:
+                raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
+
+    def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
+        super().to(device=device, dtype=dtype)
+        for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
+            match name:
+                case "timesteps":
+                    setattr(self, name, attr.to(device=device))
+                case _:
+                    setattr(self, name, attr.to(device=device, dtype=dtype))
         return self
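Since `Solver` now inherits from `fl.Module`, its `device` and `dtype` are read off the registered `scale_factors` tensor, and the overridden `to()` moves every tensor attribute while leaving the dtype of `timesteps` unchanged (only its device moves). A small behavioural sketch, assuming the defaults shown in the hunks above:

```python
import torch
from refiners.foundationals.latent_diffusion.solvers import DDPM

solver = DDPM(num_inference_steps=1000)
print(solver.device, solver.dtype)  # cpu torch.float32, derived from solver.scale_factors

solver = solver.to(dtype=torch.float16)  # scale_factors, cumulative_scale_factors, noise_std, ... become float16
print(solver.timesteps.dtype)            # timesteps keep their original dtype; only their device would change
```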

View file

@@ -7,8 +7,8 @@ from refiners.fluxion.utils import image_to_tensor, interpolate
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
-from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@@ -26,20 +26,20 @@ class StableDiffusion_1(LatentDiffusionModel):
         unet: SD1UNet | None = None,
         lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
-        scheduler: Scheduler | None = None,
+        solver: Solver | None = None,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         unet = unet or SD1UNet(in_channels=4)
         lda = lda or SD1Autoencoder()
         clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
-        scheduler = scheduler or DPMSolver(num_inference_steps=30)
+        solver = solver or DPMSolver(num_inference_steps=30)

         super().__init__(
             unet=unet,
             lda=lda,
             clip_text_encoder=clip_text_encoder,
-            scheduler=scheduler,
+            solver=solver,
             device=device,
             dtype=dtype,
         )
@@ -82,14 +82,14 @@ class StableDiffusion_1(LatentDiffusionModel):
         assert sag is not None

         degraded_latents = sag.compute_degraded_latents(
-            scheduler=self.scheduler,
+            solver=self.solver,
             latents=x,
             noise=noise,
             step=step,
             classifier_free_guidance=True,
         )

-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        timestep = self.solver.timesteps[step].unsqueeze(dim=0)
         negative_embedding, _ = clip_text_embedding.chunk(2)
         self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
         if "ip_adapter" in self.unet.provider.contexts:
@@ -111,14 +111,14 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
         unet: SD1UNet | None = None,
         lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
-        scheduler: Scheduler | None = None,
+        solver: Solver | None = None,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         self.mask_latents: Tensor | None = None
         self.target_image_latents: Tensor | None = None
         super().__init__(
-            unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
+            unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, solver=solver, device=device, dtype=dtype
         )

     def forward(
@@ -162,7 +162,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
         assert self.target_image_latents is not None

         degraded_latents = sag.compute_degraded_latents(
-            scheduler=self.scheduler,
+            solver=self.solver,
             latents=x,
             noise=noise,
             step=step,
@@ -173,7 +173,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
             dim=1,
         )

-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        timestep = self.solver.timesteps[step].unsqueeze(dim=0)
         negative_embedding, _ = clip_text_embedding.chunk(2)
         self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)

View file

@@ -3,8 +3,8 @@ from torch import Tensor, device as Device, dtype as DType
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
-from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
@@ -23,20 +23,20 @@ class StableDiffusion_XL(LatentDiffusionModel):
         unet: SDXLUNet | None = None,
         lda: SDXLAutoencoder | None = None,
         clip_text_encoder: DoubleTextEncoder | None = None,
-        scheduler: Scheduler | None = None,
+        solver: Solver | None = None,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         unet = unet or SDXLUNet(in_channels=4)
         lda = lda or SDXLAutoencoder()
         clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
-        scheduler = scheduler or DDIM(num_inference_steps=30)
+        solver = solver or DDIM(num_inference_steps=30)

         super().__init__(
             unet=unet,
             lda=lda,
             clip_text_encoder=clip_text_encoder,
-            scheduler=scheduler,
+            solver=solver,
             device=device,
             dtype=dtype,
         )
@@ -131,7 +131,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
         assert sag is not None

         degraded_latents = sag.compute_degraded_latents(
-            scheduler=self.scheduler,
+            solver=self.solver,
             latents=x,
             noise=noise,
             step=step,
@@ -140,7 +140,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
         negative_text_embedding, _ = clip_text_embedding.chunk(2)
         negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)

-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        timestep = self.solver.timesteps[step].unsqueeze(dim=0)
         time_ids, _ = time_ids.chunk(2)
         self.set_unet_context(

View file

@@ -19,7 +19,8 @@ from refiners.foundationals.latent_diffusion import (
     SD1UNet,
     StableDiffusion_1,
 )
-from refiners.foundationals.latent_diffusion.schedulers import DDPM
+from refiners.foundationals.latent_diffusion.solvers import DDPM
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.config import BaseConfig
@@ -150,7 +151,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
         return TextEmbeddingLatentsDataset(trainer=self)

     @cached_property
-    def ddpm_scheduler(self) -> DDPM:
+    def ddpm_solver(self) -> Solver:
         return DDPM(
             num_inference_steps=1000,
             device=self.device,
@@ -159,7 +160,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
     def sample_timestep(self) -> Tensor:
         random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
         self.current_step = random_step
-        return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
+        return self.ddpm_solver.timesteps[random_step].unsqueeze(dim=0)

     def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
         return sample_noise(
@@ -170,7 +171,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
         clip_text_embedding, latents = batch.text_embeddings, batch.latents
         timestep = self.sample_timestep()
         noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
-        noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step)
+        noisy_latents = self.ddpm_solver.add_noise(x=latents, noise=noise, step=self.current_step)
         self.unet.set_timestep(timestep=timestep)
         self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
         prediction = self.unet(noisy_latents)
@@ -182,7 +183,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
             unet=self.unet,
             lda=self.lda,
             clip_text_encoder=self.text_encoder,
-            scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
+            solver=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
             device=self.device,
         )
         prompts = self.config.test_diffusion.prompts

View file

@@ -24,8 +24,8 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
 from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
 from refiners.foundationals.latent_diffusion.restart import Restart
-from refiners.foundationals.latent_diffusion.schedulers import DDIM, EulerScheduler
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule
+from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 from tests.utils import ensure_similar_images
@@ -491,8 +491,8 @@ def sd15_ddim(
         warn("not running on CPU, skipping")
         pytest.skip()

-    ddim_scheduler = DDIM(num_inference_steps=20)
-    sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
+    ddim_solver = DDIM(num_inference_steps=20)
+    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)

     sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
     sd15.lda.load_from_safetensors(lda_weights)
@@ -509,8 +509,8 @@ def sd15_ddim_karras(
         warn("not running on CPU, skipping")
         pytest.skip()

-    ddim_scheduler = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
-    sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
+    ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
+    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)

     sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
     sd15.lda.load_from_safetensors(lda_weights)
@@ -527,8 +527,8 @@ def sd15_euler(
         warn("not running on CPU, skipping")
         pytest.skip()

-    euler_scheduler = EulerScheduler(num_inference_steps=30)
-    sd15 = StableDiffusion_1(scheduler=euler_scheduler, device=test_device)
+    euler_solver = Euler(num_inference_steps=30)
+    sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)

     sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
     sd15.lda.load_from_safetensors(lda_weights)
@@ -545,8 +545,8 @@ def sd15_ddim_lda_ft_mse(
         warn("not running on CPU, skipping")
         pytest.skip()

-    ddim_scheduler = DDIM(num_inference_steps=20)
-    sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
+    ddim_solver = DDIM(num_inference_steps=20)
+    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)

     sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
     sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
@@ -599,8 +599,8 @@ def sdxl_ddim(
         warn(message="not running on CPU, skipping")
         pytest.skip()

-    scheduler = DDIM(num_inference_steps=30)
-    sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
+    solver = DDIM(num_inference_steps=30)
+    sdxl = StableDiffusion_XL(solver=solver, device=test_device)

     sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
     sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
@@ -617,8 +617,8 @@ def sdxl_ddim_lda_fp16_fix(
         warn(message="not running on CPU, skipping")
         pytest.skip()

-    scheduler = DDIM(num_inference_steps=30)
-    sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
+    solver = DDIM(num_inference_steps=30)
+    sdxl = StableDiffusion_XL(solver=solver, device=test_device)

     sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
     sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
@@ -659,8 +659,8 @@ def test_diffusion_std_random_init_euler(
     sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
 ):
     sd15 = sd15_euler
-    euler_scheduler = sd15_euler.scheduler
-    assert isinstance(euler_scheduler, EulerScheduler)
+    euler_solver = sd15_euler.solver
+    assert isinstance(euler_solver, Euler)

     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
@@ -670,7 +670,7 @@ def test_diffusion_std_random_init_euler(
     manual_seed(2)

     x = torch.randn(1, 4, 64, 64, device=test_device)
-    x = x * euler_scheduler.init_noise_sigma
+    x = x * euler_solver.init_noise_sigma

     for step in sd15.steps:
         x = sd15(
@@ -1202,7 +1202,7 @@ def test_diffusion_refonly(
     for step in sd15.steps:
         noise = torch.randn(2, 4, 64, 64, device=test_device)
-        noised_guide = sd15.scheduler.add_noise(guide, noise, step)
+        noised_guide = sd15.solver.add_noise(guide, noise, step)
         refonly_adapter.set_controlnet_condition(noised_guide)
         x = sd15(
             x,
@@ -1244,7 +1244,7 @@ def test_diffusion_inpainting_refonly(
     for step in sd15.steps:
         noise = torch.randn_like(guide)
-        noised_guide = sd15.scheduler.add_noise(guide, noise, step)
+        noised_guide = sd15.solver.add_noise(guide, noise, step)
         # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
         # inpaint variation models")
         noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)

View file

@@ -5,7 +5,7 @@ import pytest
 from torch import Tensor, allclose, device as Device, equal, isclose, randn

 from refiners.fluxion import manual_seed
-from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler
+from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler


 def test_ddpm_diffusers():
@@ -83,7 +83,7 @@ def test_euler_diffusers():
         use_karras_sigmas=False,
     )
     diffusers_scheduler.set_timesteps(30)
-    refiners_scheduler = EulerScheduler(num_inference_steps=30)
+    refiners_scheduler = Euler(num_inference_steps=30)
     sample = randn(1, 4, 32, 32)
     predicted_noise = randn(1, 4, 32, 32)