mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)

make Scheduler a fl.Module + Change name Scheduler -> Solver

commit 73f6ccfc98 (parent 07cb2ff21c)
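For downstream code the change is mechanical: the `refiners.foundationals.latent_diffusion.schedulers` package becomes `solvers`, the `Scheduler` base class becomes `Solver`, `EulerScheduler` becomes `Euler`, and the `scheduler=` keyword on the latent diffusion models becomes `solver=`. A minimal before/after sketch of the new call site (illustrative usage, not taken verbatim from the diff below):

```python
# before: from refiners.foundationals.latent_diffusion.schedulers import DDIM
#         sd15 = StableDiffusion_1(scheduler=DDIM(num_inference_steps=30))
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.solvers import DDIM

# after: the solver is passed under the new keyword and module path
solver = DDIM(num_inference_steps=30)
sd15 = StableDiffusion_1(solver=solver, device="cpu")
```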
@@ -19,7 +19,7 @@ ______________________________________________________________________
 ## Latest News 🔥
 
-- Added [Euler's method](https://arxiv.org/abs/2206.00364) to schedulers (contributed by [@israfelsr](https://github.com/israfelsr))
+- Added [Euler's method](https://arxiv.org/abs/2206.00364) to solvers (contributed by [@israfelsr](https://github.com/israfelsr))
 - Added [DINOv2](https://github.com/facebookresearch/dinov2) for high-performance visual features (contributed by [@Laurent2916](https://github.com/Laurent2916))
 - Added [FreeU](https://github.com/ChenyangSi/FreeU) for improved quality at no cost (contributed by [@isamu-isozaki](https://github.com/isamu-isozaki))
 - Added [Restart Sampling](https://github.com/Newbeeer/diffusion_restart_sampling) for improved image generation ([example](https://github.com/Newbeeer/diffusion_restart_sampling/issues/4))
 
@@ -37,8 +37,8 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
     clip_text_embedding = torch.rand(1, 77, 768)
     unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
 
-    scheduler = DPMSolver(num_inference_steps=10)
-    timestep = scheduler.timesteps[0].unsqueeze(dim=0)
+    solver = DPMSolver(num_inference_steps=10)
+    timestep = solver.timesteps[0].unsqueeze(dim=0)
     unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
 
     x = torch.randn(1, 4, 64, 64)
@@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
     LatentDiffusionAutoencoder,
 )
 from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
-from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler
+from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     SD1ControlnetAdapter,
     SD1IPAdapter,
@@ -33,7 +33,7 @@ __all__ = [
     "SDXLIPAdapter",
     "SDXLT2IAdapter",
     "DPMSolver",
-    "Scheduler",
+    "Solver",
     "CLIPTextEncoderL",
     "LatentDiffusionAutoencoder",
     "SDFreeUAdapter",
@@ -7,7 +7,7 @@ from torch import Tensor, device as Device, dtype as DType
 
 import refiners.fluxion.layers as fl
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 
 T = TypeVar("T", bound="fl.Module")
 
@@ -20,7 +20,7 @@ class LatentDiffusionModel(fl.Module, ABC):
         unet: fl.Module,
         lda: LatentDiffusionAutoencoder,
         clip_text_encoder: fl.Module,
-        scheduler: Scheduler,
+        solver: Solver,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
@@ -30,10 +30,10 @@ class LatentDiffusionModel(fl.Module, ABC):
         self.unet = unet.to(device=self.device, dtype=self.dtype)
         self.lda = lda.to(device=self.device, dtype=self.dtype)
         self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
-        self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
+        self.solver = solver.to(device=self.device, dtype=self.dtype)
 
     def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
-        self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
+        self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
 
     def init_latents(
         self,
@@ -51,15 +51,15 @@ class LatentDiffusionModel(fl.Module, ABC):
         if init_image is None:
             return noise
         encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
-        return self.scheduler.add_noise(
+        return self.solver.add_noise(
             x=encoded_image,
             noise=noise,
-            step=self.scheduler.first_inference_step,
+            step=self.solver.first_inference_step,
         )
 
     @property
     def steps(self) -> list[int]:
-        return self.scheduler.inference_steps
+        return self.solver.inference_steps
 
     @abstractmethod
     def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
@@ -82,12 +82,12 @@ class LatentDiffusionModel(fl.Module, ABC):
     def forward(
         self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
     ) -> Tensor:
-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        timestep = self.solver.timesteps[step].unsqueeze(dim=0)
         self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
 
         latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
-        # scale latents for schedulers that need it
-        latents = self.scheduler.scale_model_input(latents, step=step)
+        # scale latents for solvers that need it
+        latents = self.solver.scale_model_input(latents, step=step)
         unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
 
         # classifier-free guidance
@@ -101,14 +101,14 @@ class LatentDiffusionModel(fl.Module, ABC):
             x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
         )
 
-        return self.scheduler(x, predicted_noise=predicted_noise, step=step)
+        return self.solver(x, predicted_noise=predicted_noise, step=step)
 
     def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
         return self.__class__(
             unet=self.unet.structural_copy(),
             lda=self.lda.structural_copy(),
             clip_text_encoder=self.clip_text_encoder.structural_copy(),
-            scheduler=self.scheduler,
+            solver=self.solver,
             device=self.device,
             dtype=self.dtype,
         )
@@ -51,7 +51,7 @@ class MultiDiffusion(Generic[T, D], ABC):
         match step:
             case step if step == target.start_step and target.init_latents is not None:
                 noise_view = target.crop(noise)
-                view = self.ldm.scheduler.add_noise(
+                view = self.ldm.solver.add_noise(
                     x=target.init_latents,
                     noise=noise_view,
                     step=step,
@@ -5,22 +5,22 @@ from typing import Generic, TypeVar
 import torch
 
 from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
-from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 
 T = TypeVar("T", bound=LatentDiffusionModel)
 
 
 def add_noise_interval(
-    scheduler: Scheduler,
+    solver: Solver,
     /,
     x: torch.Tensor,
     noise: torch.Tensor,
     initial_timestep: torch.Tensor,
     target_timestep: torch.Tensor,
 ) -> torch.Tensor:
-    initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
-    target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
+    initial_cumulative_scale_factors = solver.cumulative_scale_factors[initial_timestep]
+    target_cumulative_scale_factors = solver.cumulative_scale_factors[target_timestep]
 
     factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
     noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
@@ -33,7 +33,7 @@ class Restart(Generic[T]):
     Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
     (https://arxiv.org/pdf/2306.14878.pdf)
 
-    Works only with the DDIM scheduler for now.
+    Works only with the DDIM solver for now.
     """
 
     ldm: T
@@ -43,7 +43,7 @@ class Restart(Generic[T]):
     end_time: float = 2
 
     def __post_init__(self) -> None:
-        assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
+        assert isinstance(self.ldm.solver, DDIM), "Restart sampling only works with DDIM solver"
 
     def __call__(
         self,
@@ -53,15 +53,15 @@ class Restart(Generic[T]):
         condition_scale: float = 7.5,
         **kwargs: torch.Tensor,
     ) -> torch.Tensor:
-        original_scheduler = self.ldm.scheduler
-        new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
-        new_scheduler.timesteps = self.timesteps
-        self.ldm.scheduler = new_scheduler
+        original_solver = self.ldm.solver
+        new_solver = DDIM(self.ldm.solver.num_inference_steps, device=self.device, dtype=self.dtype)
+        new_solver.timesteps = self.timesteps
+        self.ldm.solver = new_solver
 
         for _ in range(self.num_iterations):
             noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
             x = add_noise_interval(
-                new_scheduler,
+                new_solver,
                 x=x,
                 noise=noise,
                 initial_timestep=self.timesteps[-1],
@@ -73,18 +73,18 @@ class Restart(Generic[T]):
                 x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
             )
 
-        self.ldm.scheduler = original_scheduler
+        self.ldm.solver = original_solver
 
         return x
 
     @cached_property
     def start_step(self) -> int:
-        sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
-        return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
+        sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
+        return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.solver.timesteps] - self.start_time)))
 
     @cached_property
     def end_timestep(self) -> int:
-        sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
+        sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
         return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))
 
     @cached_property
@@ -92,7 +92,7 @@ class Restart(Generic[T]):
         return (
             torch.round(
                 torch.linspace(
-                    start=int(self.ldm.scheduler.timesteps[self.start_step]),
+                    start=int(self.ldm.solver.timesteps[self.start_step]),
                     end=self.end_timestep,
                     steps=self.num_steps,
                 )
@@ -1,7 +0,0 @@
-from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
-from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
-from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
-from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
-
-__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
@@ -9,7 +9,7 @@ import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.utils import gaussian_blur, interpolate
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 
 if TYPE_CHECKING:
     from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@@ -89,13 +89,13 @@ class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
         return interpolate(attn_mask, Size((h, w)))
 
     def compute_degraded_latents(
-        self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
+        self, solver: Solver, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
     ) -> Tensor:
         sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
-        original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
+        original_latents = solver.remove_noise(x=latents, noise=noise, step=step)
         degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
         degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
-        return scheduler.add_noise(degraded_latents, noise=noise, step=step)
+        return solver.add_noise(degraded_latents, noise=noise, step=step)
 
     def init_context(self) -> Contexts:
         return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}
@@ -0,0 +1,7 @@
+from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
+from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
+from refiners.foundationals.latent_diffusion.solvers.euler import Euler
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
+
+__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler"]
@@ -1,9 +1,9 @@
 from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
 
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
 
 
-class DDIM(Scheduler):
+class DDIM(Solver):
     def __init__(
         self,
         num_inference_steps: int,
@@ -25,15 +25,14 @@ class DDIM(Scheduler):
             device=device,
             dtype=dtype,
         )
-        self.timesteps = self._generate_timesteps()
 
     def _generate_timesteps(self) -> Tensor:
         """
         Generates decreasing timesteps with 'leading' spacing and offset of 1
-        similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
+        similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5
         """
         step_ratio = self.num_train_timesteps // self.num_inference_steps
-        timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
+        timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
         return timesteps.flip(0)
 
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
@@ -1,9 +1,9 @@
 from torch import Generator, Tensor, arange, device as Device, dtype as DType
 
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
 
 
-class DDPM(Scheduler):
+class DDPM(Solver):
     """
     Denoising Diffusion Probabilistic Model
 
@@ -3,10 +3,10 @@ from collections import deque
 import numpy as np
 from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
 
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
 
 
-class DPMSolver(Scheduler):
+class DPMSolver(Solver):
     """
     Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
 
|
@ -48,7 +48,6 @@ class DPMSolver(Scheduler):
|
||||||
# ...and we want the same result as the original codebase.
|
# ...and we want the same result as the original codebase.
|
||||||
return tensor(
|
return tensor(
|
||||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
||||||
device=self.device,
|
|
||||||
).flip(0)
|
).flip(0)
|
||||||
|
|
||||||
def rebuild(
|
def rebuild(
|
|
@@ -2,10 +2,10 @@ import numpy as np
 import torch
 from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
 
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
 
 
-class EulerScheduler(Scheduler):
+class Euler(Solver):
     def __init__(
         self,
         num_inference_steps: int,
|
@ -40,9 +40,7 @@ class EulerScheduler(Scheduler):
|
||||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||||
# torch.linspace(0,999,31)[15] is 499.5
|
# torch.linspace(0,999,31)[15] is 499.5
|
||||||
# ...and we want the same result as the original codebase.
|
# ...and we want the same result as the original codebase.
|
||||||
timesteps = torch.tensor(
|
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
|
||||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
|
|
||||||
).flip(0)
|
|
||||||
return timesteps
|
return timesteps
|
||||||
|
|
||||||
def _generate_sigmas(self) -> Tensor:
|
def _generate_sigmas(self) -> Tensor:
|
|
@@ -4,7 +4,9 @@ from typing import TypeVar
 
 from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
 
-T = TypeVar("T", bound="Scheduler")
+from refiners.fluxion import layers as fl
+
+T = TypeVar("T", bound="Solver")
 
 
 class NoiseSchedule(str, Enum):
@@ -13,11 +15,11 @@ class NoiseSchedule(str, Enum):
     KARRAS = "karras"
 
 
-class Scheduler(ABC):
+class Solver(fl.Module, ABC):
     """
-    A base class for creating a diffusion model scheduler.
+    A base class for creating a diffusion model solver.
 
-    The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
+    Solver creates a sequence of noise and scaling factors used in the diffusion process,
     which gradually transforms the original data distribution into a Gaussian one.
 
     This process is described using several parameters such as initial and final diffusion rates,
@@ -36,9 +38,8 @@ class Scheduler(ABC):
         first_inference_step: int = 0,
         device: Device | str = "cpu",
         dtype: DType = float32,
-    ):
-        self.device: Device = Device(device)
-        self.dtype: DType = dtype
+    ) -> None:
+        super().__init__()
         self.num_inference_steps = num_inference_steps
         self.num_train_timesteps = num_train_timesteps
         self.initial_diffusion_rate = initial_diffusion_rate
@@ -50,6 +51,7 @@ class Scheduler(ABC):
         self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
         self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
         self.timesteps = self._generate_timesteps()
+        self.to(device=device, dtype=dtype)
 
     @abstractmethod
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
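The base constructor now calls `super().__init__()` (from `fl.Module`), builds `self.timesteps` itself, and finishes with `self.to(device=device, dtype=dtype)`, which is why the DDIM hunk above drops its own duplicate `self.timesteps = self._generate_timesteps()` assignment. Below is a hypothetical minimal subclass sketching only the two hooks a solver must provide (`_generate_timesteps` and `__call__`); the update rule here is deliberately naive and is not taken from the codebase:

```python
from torch import Generator, Tensor, arange

from refiners.foundationals.latent_diffusion.solvers.solver import Solver


class ToySolver(Solver):  # hypothetical example, for illustration only
    def _generate_timesteps(self) -> Tensor:
        # decreasing timesteps with uniform spacing, similar to DDIM but without the +1 offset
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        return (arange(start=0, end=self.num_inference_steps, step=1) * step_ratio).flip(0)

    def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
        # naive denoising update: remove the predicted noise scaled by the current noise std
        return x - self.noise_std[self.timesteps[step]] * predicted_noise
```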
@@ -69,57 +71,6 @@ class Scheduler(ABC):
         """
         ...
 
-    @property
-    def all_steps(self) -> list[int]:
-        return list(range(self.num_inference_steps))
-
-    @property
-    def inference_steps(self) -> list[int]:
-        return self.all_steps[self.first_inference_step :]
-
-    def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
-        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
-        first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
-        return self.__class__(
-            num_inference_steps=num_inference_steps,
-            num_train_timesteps=self.num_train_timesteps,
-            initial_diffusion_rate=self.initial_diffusion_rate,
-            final_diffusion_rate=self.final_diffusion_rate,
-            noise_schedule=self.noise_schedule,
-            first_inference_step=first_inference_step,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
-        """
-        For compatibility with schedulers that need to scale the input according to the current timestep.
-        """
-        return x
-
-    def sample_power_distribution(self, power: float = 2, /) -> Tensor:
-        return (
-            linspace(
-                start=self.initial_diffusion_rate ** (1 / power),
-                end=self.final_diffusion_rate ** (1 / power),
-                steps=self.num_train_timesteps,
-                device=self.device,
-                dtype=self.dtype,
-            )
-            ** power
-        )
-
-    def sample_noise_schedule(self) -> Tensor:
-        match self.noise_schedule:
-            case "uniform":
-                return 1 - self.sample_power_distribution(1)
-            case "quadratic":
-                return 1 - self.sample_power_distribution(2)
-            case "karras":
-                return 1 - self.sample_power_distribution(7)
-            case _:
-                raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
-
     def add_noise(
         self,
         x: Tensor,
@@ -141,14 +92,77 @@
         denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
         return denoised_x
 
-    def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
-        if device is not None:
-            self.device = Device(device)
-            self.timesteps = self.timesteps.to(device)
-        if dtype is not None:
-            self.dtype = dtype
-        self.scale_factors = self.scale_factors.to(device, dtype=dtype)
-        self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
-        self.noise_std = self.noise_std.to(device, dtype=dtype)
-        self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
+    @property
+    def all_steps(self) -> list[int]:
+        return list(range(self.num_inference_steps))
+
+    @property
+    def inference_steps(self) -> list[int]:
+        return self.all_steps[self.first_inference_step :]
+
+    @property
+    def device(self) -> Device:
+        return self.scale_factors.device
+
+    @property
+    def dtype(self) -> DType:
+        return self.scale_factors.dtype
+
+    @device.setter
+    def device(self, device: Device | str | None = None) -> None:
+        self.to(device=device)
+
+    @dtype.setter
+    def dtype(self, dtype: DType | None = None) -> None:
+        self.to(dtype=dtype)
+
+    def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
+        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
+        first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
+        return self.__class__(
+            num_inference_steps=num_inference_steps,
+            num_train_timesteps=self.num_train_timesteps,
+            initial_diffusion_rate=self.initial_diffusion_rate,
+            final_diffusion_rate=self.final_diffusion_rate,
+            noise_schedule=self.noise_schedule,
+            first_inference_step=first_inference_step,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
+        """
+        For compatibility with solvers that need to scale the input according to the current timestep.
+        """
+        return x
+
+    def sample_power_distribution(self, power: float = 2, /) -> Tensor:
+        return (
+            linspace(
+                start=self.initial_diffusion_rate ** (1 / power),
+                end=self.final_diffusion_rate ** (1 / power),
+                steps=self.num_train_timesteps,
+            )
+            ** power
+        )
+
+    def sample_noise_schedule(self) -> Tensor:
+        match self.noise_schedule:
+            case "uniform":
+                return 1 - self.sample_power_distribution(1)
+            case "quadratic":
+                return 1 - self.sample_power_distribution(2)
+            case "karras":
+                return 1 - self.sample_power_distribution(7)
+            case _:
+                raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
+
+    def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
+        super().to(device=device, dtype=dtype)
+        for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
+            match name:
+                case "timesteps":
+                    setattr(self, name, attr.to(device=device))
+                case _:
+                    setattr(self, name, attr.to(device=device, dtype=dtype))
         return self
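With `Solver` now an `fl.Module`, `device` and `dtype` become derived properties (read off `scale_factors`) rather than stored attributes, and the overridden `to()` moves every tensor attribute while deliberately leaving `timesteps` in its integer dtype, since casting timesteps to a float dtype would break their use as indices. A sketch of the resulting behavior (assuming the post-commit API):

```python
import torch

from refiners.foundationals.latent_diffusion.solvers import DDIM

solver = DDIM(num_inference_steps=30)
solver = solver.to(dtype=torch.float16)  # to() returns self, so reassignment is optional

assert solver.scale_factors.dtype == torch.float16  # float tensors follow the requested dtype
assert solver.timesteps.dtype == torch.int64  # timesteps keep their integer dtype
assert solver.dtype == torch.float16  # dtype is now read from scale_factors
```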
@@ -7,8 +7,8 @@ from refiners.fluxion.utils import image_to_tensor, interpolate
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
-from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 
@@ -26,20 +26,20 @@ class StableDiffusion_1(LatentDiffusionModel):
         unet: SD1UNet | None = None,
         lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
-        scheduler: Scheduler | None = None,
+        solver: Solver | None = None,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         unet = unet or SD1UNet(in_channels=4)
         lda = lda or SD1Autoencoder()
         clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
-        scheduler = scheduler or DPMSolver(num_inference_steps=30)
+        solver = solver or DPMSolver(num_inference_steps=30)
 
         super().__init__(
             unet=unet,
             lda=lda,
             clip_text_encoder=clip_text_encoder,
-            scheduler=scheduler,
+            solver=solver,
             device=device,
             dtype=dtype,
         )
|
@ -82,14 +82,14 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
assert sag is not None
|
assert sag is not None
|
||||||
|
|
||||||
degraded_latents = sag.compute_degraded_latents(
|
degraded_latents = sag.compute_degraded_latents(
|
||||||
scheduler=self.scheduler,
|
solver=self.solver,
|
||||||
latents=x,
|
latents=x,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
step=step,
|
step=step,
|
||||||
classifier_free_guidance=True,
|
classifier_free_guidance=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
|
||||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||||
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
|
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
|
||||||
if "ip_adapter" in self.unet.provider.contexts:
|
if "ip_adapter" in self.unet.provider.contexts:
|
||||||
|
@@ -111,14 +111,14 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
         unet: SD1UNet | None = None,
         lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
-        scheduler: Scheduler | None = None,
+        solver: Solver | None = None,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         self.mask_latents: Tensor | None = None
         self.target_image_latents: Tensor | None = None
         super().__init__(
-            unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
+            unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, solver=solver, device=device, dtype=dtype
         )
 
     def forward(
@@ -162,7 +162,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
         assert self.target_image_latents is not None
 
         degraded_latents = sag.compute_degraded_latents(
-            scheduler=self.scheduler,
+            solver=self.solver,
             latents=x,
             noise=noise,
             step=step,
@@ -173,7 +173,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
             dim=1,
         )
 
-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        timestep = self.solver.timesteps[step].unsqueeze(dim=0)
         negative_embedding, _ = clip_text_embedding.chunk(2)
         self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
 
@@ -3,8 +3,8 @@ from torch import Tensor, device as Device, dtype as DType
 
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
-from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
@@ -23,20 +23,20 @@ class StableDiffusion_XL(LatentDiffusionModel):
         unet: SDXLUNet | None = None,
         lda: SDXLAutoencoder | None = None,
         clip_text_encoder: DoubleTextEncoder | None = None,
-        scheduler: Scheduler | None = None,
+        solver: Solver | None = None,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         unet = unet or SDXLUNet(in_channels=4)
         lda = lda or SDXLAutoencoder()
         clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
-        scheduler = scheduler or DDIM(num_inference_steps=30)
+        solver = solver or DDIM(num_inference_steps=30)
 
         super().__init__(
             unet=unet,
             lda=lda,
             clip_text_encoder=clip_text_encoder,
-            scheduler=scheduler,
+            solver=solver,
             device=device,
             dtype=dtype,
         )
@@ -131,7 +131,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
         assert sag is not None
 
         degraded_latents = sag.compute_degraded_latents(
-            scheduler=self.scheduler,
+            solver=self.solver,
             latents=x,
             noise=noise,
             step=step,
@@ -140,7 +140,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
 
         negative_text_embedding, _ = clip_text_embedding.chunk(2)
         negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        timestep = self.solver.timesteps[step].unsqueeze(dim=0)
         time_ids, _ = time_ids.chunk(2)
 
         self.set_unet_context(
@@ -19,7 +19,8 @@ from refiners.foundationals.latent_diffusion import (
     SD1UNet,
     StableDiffusion_1,
 )
-from refiners.foundationals.latent_diffusion.schedulers import DDPM
+from refiners.foundationals.latent_diffusion.solvers import DDPM
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.config import BaseConfig
@@ -150,7 +151,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
         return TextEmbeddingLatentsDataset(trainer=self)
 
     @cached_property
-    def ddpm_scheduler(self) -> DDPM:
+    def ddpm_solver(self) -> Solver:
         return DDPM(
             num_inference_steps=1000,
             device=self.device,
@@ -159,7 +160,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
     def sample_timestep(self) -> Tensor:
         random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
         self.current_step = random_step
-        return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
+        return self.ddpm_solver.timesteps[random_step].unsqueeze(dim=0)
 
     def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
         return sample_noise(
@@ -170,7 +171,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
         clip_text_embedding, latents = batch.text_embeddings, batch.latents
         timestep = self.sample_timestep()
         noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
-        noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step)
+        noisy_latents = self.ddpm_solver.add_noise(x=latents, noise=noise, step=self.current_step)
         self.unet.set_timestep(timestep=timestep)
         self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
         prediction = self.unet(noisy_latents)
@@ -182,7 +183,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
             unet=self.unet,
             lda=self.lda,
             clip_text_encoder=self.text_encoder,
-            scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
+            solver=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
             device=self.device,
         )
         prompts = self.config.test_diffusion.prompts
@@ -24,8 +24,8 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
 from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
 from refiners.foundationals.latent_diffusion.restart import Restart
-from refiners.foundationals.latent_diffusion.schedulers import DDIM, EulerScheduler
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule
+from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 from tests.utils import ensure_similar_images
@@ -491,8 +491,8 @@ def sd15_ddim(
         warn("not running on CPU, skipping")
         pytest.skip()
 
-    ddim_scheduler = DDIM(num_inference_steps=20)
-    sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
+    ddim_solver = DDIM(num_inference_steps=20)
+    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
 
     sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
     sd15.lda.load_from_safetensors(lda_weights)
@@ -509,8 +509,8 @@ def sd15_ddim_karras(
         warn("not running on CPU, skipping")
         pytest.skip()
 
-    ddim_scheduler = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
-    sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
+    ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
+    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
 
     sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
     sd15.lda.load_from_safetensors(lda_weights)
@@ -527,8 +527,8 @@ def sd15_euler(
         warn("not running on CPU, skipping")
         pytest.skip()
 
-    euler_scheduler = EulerScheduler(num_inference_steps=30)
-    sd15 = StableDiffusion_1(scheduler=euler_scheduler, device=test_device)
+    euler_solver = Euler(num_inference_steps=30)
+    sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
 
     sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
     sd15.lda.load_from_safetensors(lda_weights)
@@ -545,8 +545,8 @@ def sd15_ddim_lda_ft_mse(
         warn("not running on CPU, skipping")
         pytest.skip()
 
-    ddim_scheduler = DDIM(num_inference_steps=20)
-    sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
+    ddim_solver = DDIM(num_inference_steps=20)
+    sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
 
     sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
     sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
@@ -599,8 +599,8 @@ def sdxl_ddim(
         warn(message="not running on CPU, skipping")
         pytest.skip()
 
-    scheduler = DDIM(num_inference_steps=30)
-    sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
+    solver = DDIM(num_inference_steps=30)
+    sdxl = StableDiffusion_XL(solver=solver, device=test_device)
 
     sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
     sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
@@ -617,8 +617,8 @@ def sdxl_ddim_lda_fp16_fix(
         warn(message="not running on CPU, skipping")
         pytest.skip()
 
-    scheduler = DDIM(num_inference_steps=30)
-    sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
+    solver = DDIM(num_inference_steps=30)
+    sdxl = StableDiffusion_XL(solver=solver, device=test_device)
 
     sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
     sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
@@ -659,8 +659,8 @@ def test_diffusion_std_random_init_euler(
     sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
 ):
     sd15 = sd15_euler
-    euler_scheduler = sd15_euler.scheduler
-    assert isinstance(euler_scheduler, EulerScheduler)
+    euler_solver = sd15_euler.solver
+    assert isinstance(euler_solver, Euler)
 
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
@@ -670,7 +670,7 @@ def test_diffusion_std_random_init_euler(
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    x = x * euler_scheduler.init_noise_sigma
+    x = x * euler_solver.init_noise_sigma
 
     for step in sd15.steps:
         x = sd15(
@@ -1202,7 +1202,7 @@ def test_diffusion_refonly(
 
     for step in sd15.steps:
        noise = torch.randn(2, 4, 64, 64, device=test_device)
-        noised_guide = sd15.scheduler.add_noise(guide, noise, step)
+        noised_guide = sd15.solver.add_noise(guide, noise, step)
         refonly_adapter.set_controlnet_condition(noised_guide)
         x = sd15(
             x,
@@ -1244,7 +1244,7 @@ def test_diffusion_inpainting_refonly(
 
     for step in sd15.steps:
         noise = torch.randn_like(guide)
-        noised_guide = sd15.scheduler.add_noise(guide, noise, step)
+        noised_guide = sd15.solver.add_noise(guide, noise, step)
         # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
         # inpaint variation models")
         noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
@@ -5,7 +5,7 @@ import pytest
 from torch import Tensor, allclose, device as Device, equal, isclose, randn
 
 from refiners.fluxion import manual_seed
-from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler
+from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler
 
 
 def test_ddpm_diffusers():
@@ -83,7 +83,7 @@ def test_euler_diffusers():
         use_karras_sigmas=False,
     )
     diffusers_scheduler.set_timesteps(30)
-    refiners_scheduler = EulerScheduler(num_inference_steps=30)
+    refiners_scheduler = Euler(num_inference_steps=30)
 
     sample = randn(1, 4, 32, 32)
     predicted_noise = randn(1, 4, 32, 32)