make Scheduler a fl.Module + Change name Scheduler -> Solver

This commit is contained in:
limiteinductive 2024-01-31 14:07:34 +00:00 committed by Benjamin Trom
parent 07cb2ff21c
commit 73f6ccfc98
19 changed files with 184 additions and 173 deletions

View file

@ -19,7 +19,7 @@ ______________________________________________________________________
## Latest News 🔥
- Added [Euler's method](https://arxiv.org/abs/2206.00364) to schedulers (contributed by [@israfelsr](https://github.com/israfelsr))
- Added [Euler's method](https://arxiv.org/abs/2206.00364) to solvers (contributed by [@israfelsr](https://github.com/israfelsr))
- Added [DINOv2](https://github.com/facebookresearch/dinov2) for high-performance visual features (contributed by [@Laurent2916](https://github.com/Laurent2916))
- Added [FreeU](https://github.com/ChenyangSi/FreeU) for improved quality at no cost (contributed by [@isamu-isozaki](https://github.com/isamu-isozaki))
- Added [Restart Sampling](https://github.com/Newbeeer/diffusion_restart_sampling) for improved image generation ([example](https://github.com/Newbeeer/diffusion_restart_sampling/issues/4))

View file

@ -37,8 +37,8 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
clip_text_embedding = torch.rand(1, 77, 768)
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
scheduler = DPMSolver(num_inference_steps=10)
timestep = scheduler.timesteps[0].unsqueeze(dim=0)
solver = DPMSolver(num_inference_steps=10)
timestep = solver.timesteps[0].unsqueeze(dim=0)
unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
x = torch.randn(1, 4, 64, 64)

View file

@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
LatentDiffusionAutoencoder,
)
from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
SD1ControlnetAdapter,
SD1IPAdapter,
@ -33,7 +33,7 @@ __all__ = [
"SDXLIPAdapter",
"SDXLT2IAdapter",
"DPMSolver",
"Scheduler",
"Solver",
"CLIPTextEncoderL",
"LatentDiffusionAutoencoder",
"SDFreeUAdapter",

View file

@ -7,7 +7,7 @@ from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
T = TypeVar("T", bound="fl.Module")
@ -20,7 +20,7 @@ class LatentDiffusionModel(fl.Module, ABC):
unet: fl.Module,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: fl.Module,
scheduler: Scheduler,
solver: Solver,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
@ -30,10 +30,10 @@ class LatentDiffusionModel(fl.Module, ABC):
self.unet = unet.to(device=self.device, dtype=self.dtype)
self.lda = lda.to(device=self.device, dtype=self.dtype)
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
self.solver = solver.to(device=self.device, dtype=self.dtype)
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
def init_latents(
self,
@ -51,15 +51,15 @@ class LatentDiffusionModel(fl.Module, ABC):
if init_image is None:
return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
return self.scheduler.add_noise(
return self.solver.add_noise(
x=encoded_image,
noise=noise,
step=self.scheduler.first_inference_step,
step=self.solver.first_inference_step,
)
@property
def steps(self) -> list[int]:
return self.scheduler.inference_steps
return self.solver.inference_steps
@abstractmethod
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
@ -82,12 +82,12 @@ class LatentDiffusionModel(fl.Module, ABC):
def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
) -> Tensor:
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
# scale latents for schedulers that need it
latents = self.scheduler.scale_model_input(latents, step=step)
# scale latents for solvers that need it
latents = self.solver.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance
@ -101,14 +101,14 @@ class LatentDiffusionModel(fl.Module, ABC):
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
)
return self.scheduler(x, predicted_noise=predicted_noise, step=step)
return self.solver(x, predicted_noise=predicted_noise, step=step)
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
return self.__class__(
unet=self.unet.structural_copy(),
lda=self.lda.structural_copy(),
clip_text_encoder=self.clip_text_encoder.structural_copy(),
scheduler=self.scheduler,
solver=self.solver,
device=self.device,
dtype=self.dtype,
)

View file

@ -51,7 +51,7 @@ class MultiDiffusion(Generic[T, D], ABC):
match step:
case step if step == target.start_step and target.init_latents is not None:
noise_view = target.crop(noise)
view = self.ldm.scheduler.add_noise(
view = self.ldm.solver.add_noise(
x=target.init_latents,
noise=noise_view,
step=step,

View file

@ -5,22 +5,22 @@ from typing import Generic, TypeVar
import torch
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
T = TypeVar("T", bound=LatentDiffusionModel)
def add_noise_interval(
scheduler: Scheduler,
solver: Solver,
/,
x: torch.Tensor,
noise: torch.Tensor,
initial_timestep: torch.Tensor,
target_timestep: torch.Tensor,
) -> torch.Tensor:
initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
initial_cumulative_scale_factors = solver.cumulative_scale_factors[initial_timestep]
target_cumulative_scale_factors = solver.cumulative_scale_factors[target_timestep]
factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
@ -33,7 +33,7 @@ class Restart(Generic[T]):
Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
(https://arxiv.org/pdf/2306.14878.pdf)
Works only with the DDIM scheduler for now.
Works only with the DDIM solver for now.
"""
ldm: T
@ -43,7 +43,7 @@ class Restart(Generic[T]):
end_time: float = 2
def __post_init__(self) -> None:
assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
assert isinstance(self.ldm.solver, DDIM), "Restart sampling only works with DDIM solver"
def __call__(
self,
@ -53,15 +53,15 @@ class Restart(Generic[T]):
condition_scale: float = 7.5,
**kwargs: torch.Tensor,
) -> torch.Tensor:
original_scheduler = self.ldm.scheduler
new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
new_scheduler.timesteps = self.timesteps
self.ldm.scheduler = new_scheduler
original_solver = self.ldm.solver
new_solver = DDIM(self.ldm.solver.num_inference_steps, device=self.device, dtype=self.dtype)
new_solver.timesteps = self.timesteps
self.ldm.solver = new_solver
for _ in range(self.num_iterations):
noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
x = add_noise_interval(
new_scheduler,
new_solver,
x=x,
noise=noise,
initial_timestep=self.timesteps[-1],
@ -73,18 +73,18 @@ class Restart(Generic[T]):
x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
)
self.ldm.scheduler = original_scheduler
self.ldm.solver = original_solver
return x
@cached_property
def start_step(self) -> int:
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.solver.timesteps] - self.start_time)))
@cached_property
def end_timestep(self) -> int:
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))
@cached_property
@ -92,7 +92,7 @@ class Restart(Generic[T]):
return (
torch.round(
torch.linspace(
start=int(self.ldm.scheduler.timesteps[self.start_step]),
start=int(self.ldm.solver.timesteps[self.start_step]),
end=self.end_timestep,
steps=self.num_steps,
)

View file

@ -1,7 +0,0 @@
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]

View file

@ -9,7 +9,7 @@ import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.fluxion.utils import gaussian_blur, interpolate
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
if TYPE_CHECKING:
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@ -89,13 +89,13 @@ class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
return interpolate(attn_mask, Size((h, w)))
def compute_degraded_latents(
self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
self, solver: Solver, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
) -> Tensor:
sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
original_latents = solver.remove_noise(x=latents, noise=noise, step=step)
degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
return scheduler.add_noise(degraded_latents, noise=noise, step=step)
return solver.add_noise(degraded_latents, noise=noise, step=step)
def init_context(self) -> Contexts:
return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}

View file

@ -0,0 +1,7 @@
from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler"]

View file

@ -1,9 +1,9 @@
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
class DDIM(Scheduler):
class DDIM(Solver):
def __init__(
self,
num_inference_steps: int,
@ -25,15 +25,14 @@ class DDIM(Scheduler):
device=device,
dtype=dtype,
)
self.timesteps = self._generate_timesteps()
def _generate_timesteps(self) -> Tensor:
"""
Generates decreasing timesteps with 'leading' spacing and offset of 1
similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5
"""
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
return timesteps.flip(0)
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:

View file

@ -1,9 +1,9 @@
from torch import Generator, Tensor, arange, device as Device, dtype as DType
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
class DDPM(Scheduler):
class DDPM(Solver):
"""
Denoising Diffusion Probabilistic Model

View file

@ -3,10 +3,10 @@ from collections import deque
import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
class DPMSolver(Scheduler):
class DPMSolver(Solver):
"""
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
@ -48,7 +48,6 @@ class DPMSolver(Scheduler):
# ...and we want the same result as the original codebase.
return tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
device=self.device,
).flip(0)
def rebuild(

View file

@ -2,10 +2,10 @@ import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
class EulerScheduler(Scheduler):
class Euler(Solver):
def __init__(
self,
num_inference_steps: int,
@ -40,9 +40,7 @@ class EulerScheduler(Scheduler):
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
timesteps = torch.tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
).flip(0)
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
return timesteps
def _generate_sigmas(self) -> Tensor:

View file

@ -4,7 +4,9 @@ from typing import TypeVar
from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
T = TypeVar("T", bound="Scheduler")
from refiners.fluxion import layers as fl
T = TypeVar("T", bound="Solver")
class NoiseSchedule(str, Enum):
@ -13,11 +15,11 @@ class NoiseSchedule(str, Enum):
KARRAS = "karras"
class Scheduler(ABC):
class Solver(fl.Module, ABC):
"""
A base class for creating a diffusion model scheduler.
A base class for creating a diffusion model solver.
The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
Solver creates a sequence of noise and scaling factors used in the diffusion process,
which gradually transforms the original data distribution into a Gaussian one.
This process is described using several parameters such as initial and final diffusion rates,
@ -36,9 +38,8 @@ class Scheduler(ABC):
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: DType = float32,
):
self.device: Device = Device(device)
self.dtype: DType = dtype
) -> None:
super().__init__()
self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps
self.initial_diffusion_rate = initial_diffusion_rate
@ -50,6 +51,7 @@ class Scheduler(ABC):
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
self.timesteps = self._generate_timesteps()
self.to(device=device, dtype=dtype)
@abstractmethod
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
@ -69,57 +71,6 @@ class Scheduler(ABC):
"""
...
@property
def all_steps(self) -> list[int]:
return list(range(self.num_inference_steps))
@property
def inference_steps(self) -> list[int]:
return self.all_steps[self.first_inference_step :]
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
return self.__class__(
num_inference_steps=num_inference_steps,
num_train_timesteps=self.num_train_timesteps,
initial_diffusion_rate=self.initial_diffusion_rate,
final_diffusion_rate=self.final_diffusion_rate,
noise_schedule=self.noise_schedule,
first_inference_step=first_inference_step,
device=self.device,
dtype=self.dtype,
)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""
For compatibility with schedulers that need to scale the input according to the current timestep.
"""
return x
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps,
device=self.device,
dtype=self.dtype,
)
** power
)
def sample_noise_schedule(self) -> Tensor:
match self.noise_schedule:
case "uniform":
return 1 - self.sample_power_distribution(1)
case "quadratic":
return 1 - self.sample_power_distribution(2)
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
def add_noise(
self,
x: Tensor,
@ -141,14 +92,77 @@ class Scheduler(ABC):
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
if device is not None:
self.device = Device(device)
self.timesteps = self.timesteps.to(device)
if dtype is not None:
self.dtype = dtype
self.scale_factors = self.scale_factors.to(device, dtype=dtype)
self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
self.noise_std = self.noise_std.to(device, dtype=dtype)
self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
@property
def all_steps(self) -> list[int]:
return list(range(self.num_inference_steps))
@property
def inference_steps(self) -> list[int]:
return self.all_steps[self.first_inference_step :]
@property
def device(self) -> Device:
return self.scale_factors.device
@property
def dtype(self) -> DType:
return self.scale_factors.dtype
@device.setter
def device(self, device: Device | str | None = None) -> None:
self.to(device=device)
@dtype.setter
def dtype(self, dtype: DType | None = None) -> None:
self.to(dtype=dtype)
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
return self.__class__(
num_inference_steps=num_inference_steps,
num_train_timesteps=self.num_train_timesteps,
initial_diffusion_rate=self.initial_diffusion_rate,
final_diffusion_rate=self.final_diffusion_rate,
noise_schedule=self.noise_schedule,
first_inference_step=first_inference_step,
device=self.device,
dtype=self.dtype,
)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""
For compatibility with solvers that need to scale the input according to the current timestep.
"""
return x
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps,
)
** power
)
def sample_noise_schedule(self) -> Tensor:
match self.noise_schedule:
case "uniform":
return 1 - self.sample_power_distribution(1)
case "quadratic":
return 1 - self.sample_power_distribution(2)
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
super().to(device=device, dtype=dtype)
for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
match name:
case "timesteps":
setattr(self, name, attr.to(device=device))
case _:
setattr(self, name, attr.to(device=device, dtype=dtype))
return self

View file

@ -7,8 +7,8 @@ from refiners.fluxion.utils import image_to_tensor, interpolate
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@ -26,20 +26,20 @@ class StableDiffusion_1(LatentDiffusionModel):
unet: SD1UNet | None = None,
lda: SD1Autoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
solver: Solver | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SD1UNet(in_channels=4)
lda = lda or SD1Autoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DPMSolver(num_inference_steps=30)
solver = solver or DPMSolver(num_inference_steps=30)
super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
solver=solver,
device=device,
dtype=dtype,
)
@ -82,14 +82,14 @@ class StableDiffusion_1(LatentDiffusionModel):
assert sag is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
solver=self.solver,
latents=x,
noise=noise,
step=step,
classifier_free_guidance=True,
)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
negative_embedding, _ = clip_text_embedding.chunk(2)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
if "ip_adapter" in self.unet.provider.contexts:
@ -111,14 +111,14 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
unet: SD1UNet | None = None,
lda: SD1Autoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
solver: Solver | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
self.mask_latents: Tensor | None = None
self.target_image_latents: Tensor | None = None
super().__init__(
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, solver=solver, device=device, dtype=dtype
)
def forward(
@ -162,7 +162,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
assert self.target_image_latents is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
solver=self.solver,
latents=x,
noise=noise,
step=step,
@ -173,7 +173,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
dim=1,
)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
negative_embedding, _ = clip_text_embedding.chunk(2)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)

View file

@ -3,8 +3,8 @@ from torch import Tensor, device as Device, dtype as DType
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
@ -23,20 +23,20 @@ class StableDiffusion_XL(LatentDiffusionModel):
unet: SDXLUNet | None = None,
lda: SDXLAutoencoder | None = None,
clip_text_encoder: DoubleTextEncoder | None = None,
scheduler: Scheduler | None = None,
solver: Solver | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SDXLUNet(in_channels=4)
lda = lda or SDXLAutoencoder()
clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
scheduler = scheduler or DDIM(num_inference_steps=30)
solver = solver or DDIM(num_inference_steps=30)
super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
solver=solver,
device=device,
dtype=dtype,
)
@ -131,7 +131,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
assert sag is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
solver=self.solver,
latents=x,
noise=noise,
step=step,
@ -140,7 +140,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
negative_text_embedding, _ = clip_text_embedding.chunk(2)
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
time_ids, _ = time_ids.chunk(2)
self.set_unet_context(

View file

@ -19,7 +19,8 @@ from refiners.foundationals.latent_diffusion import (
SD1UNet,
StableDiffusion_1,
)
from refiners.foundationals.latent_diffusion.schedulers import DDPM
from refiners.foundationals.latent_diffusion.solvers import DDPM
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
from refiners.training_utils.callback import Callback
from refiners.training_utils.config import BaseConfig
@ -150,7 +151,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
return TextEmbeddingLatentsDataset(trainer=self)
@cached_property
def ddpm_scheduler(self) -> DDPM:
def ddpm_solver(self) -> Solver:
return DDPM(
num_inference_steps=1000,
device=self.device,
@ -159,7 +160,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
def sample_timestep(self) -> Tensor:
random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
self.current_step = random_step
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
return self.ddpm_solver.timesteps[random_step].unsqueeze(dim=0)
def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
return sample_noise(
@ -170,7 +171,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
clip_text_embedding, latents = batch.text_embeddings, batch.latents
timestep = self.sample_timestep()
noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step)
noisy_latents = self.ddpm_solver.add_noise(x=latents, noise=noise, step=self.current_step)
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
prediction = self.unet(noisy_latents)
@ -182,7 +183,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
unet=self.unet,
lda=self.lda,
clip_text_encoder=self.text_encoder,
scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
solver=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
device=self.device,
)
prompts = self.config.test_diffusion.prompts

View file

@ -24,8 +24,8 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.schedulers import DDIM, EulerScheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from tests.utils import ensure_similar_images
@ -491,8 +491,8 @@ def sd15_ddim(
warn("not running on CPU, skipping")
pytest.skip()
ddim_scheduler = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
ddim_solver = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
@ -509,8 +509,8 @@ def sd15_ddim_karras(
warn("not running on CPU, skipping")
pytest.skip()
ddim_scheduler = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
@ -527,8 +527,8 @@ def sd15_euler(
warn("not running on CPU, skipping")
pytest.skip()
euler_scheduler = EulerScheduler(num_inference_steps=30)
sd15 = StableDiffusion_1(scheduler=euler_scheduler, device=test_device)
euler_solver = Euler(num_inference_steps=30)
sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
@ -545,8 +545,8 @@ def sd15_ddim_lda_ft_mse(
warn("not running on CPU, skipping")
pytest.skip()
ddim_scheduler = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
ddim_solver = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
@ -599,8 +599,8 @@ def sdxl_ddim(
warn(message="not running on CPU, skipping")
pytest.skip()
scheduler = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
solver = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
@ -617,8 +617,8 @@ def sdxl_ddim_lda_fp16_fix(
warn(message="not running on CPU, skipping")
pytest.skip()
scheduler = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
solver = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
@ -659,8 +659,8 @@ def test_diffusion_std_random_init_euler(
sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
):
sd15 = sd15_euler
euler_scheduler = sd15_euler.scheduler
assert isinstance(euler_scheduler, EulerScheduler)
euler_solver = sd15_euler.solver
assert isinstance(euler_solver, Euler)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
@ -670,7 +670,7 @@ def test_diffusion_std_random_init_euler(
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
x = x * euler_scheduler.init_noise_sigma
x = x * euler_solver.init_noise_sigma
for step in sd15.steps:
x = sd15(
@ -1202,7 +1202,7 @@ def test_diffusion_refonly(
for step in sd15.steps:
noise = torch.randn(2, 4, 64, 64, device=test_device)
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
noised_guide = sd15.solver.add_noise(guide, noise, step)
refonly_adapter.set_controlnet_condition(noised_guide)
x = sd15(
x,
@ -1244,7 +1244,7 @@ def test_diffusion_inpainting_refonly(
for step in sd15.steps:
noise = torch.randn_like(guide)
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
noised_guide = sd15.solver.add_noise(guide, noise, step)
# See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
# inpaint variation models")
noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)

View file

@ -5,7 +5,7 @@ import pytest
from torch import Tensor, allclose, device as Device, equal, isclose, randn
from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler
def test_ddpm_diffusers():
@ -83,7 +83,7 @@ def test_euler_diffusers():
use_karras_sigmas=False,
)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = EulerScheduler(num_inference_steps=30)
refiners_scheduler = Euler(num_inference_steps=30)
sample = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)