diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py
index 891de1a..d46f7e1 100644
--- a/scripts/conversion/convert_diffusers_unet.py
+++ b/scripts/conversion/convert_diffusers_unet.py
@@ -8,7 +8,7 @@ from torch import nn

 from refiners.fluxion.model_converter import ModelConverter
 from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter


 class Args(argparse.Namespace):
@@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter:

     if source_is_lcm:
         assert isinstance(target, SDXLUNet)
-        LcmAdapter(target=target).inject()
+        SDXLLcmAdapter(target=target).inject()

     x = torch.randn(1, source_in_channels, 32, 32)
     timestep = torch.tensor(data=[0])
diff --git a/src/refiners/fluxion/adapters/__init__.py b/src/refiners/fluxion/adapters/__init__.py
index 8d8c766..727006c 100644
--- a/src/refiners/fluxion/adapters/__init__.py
+++ b/src/refiners/fluxion/adapters/__init__.py
@@ -1,5 +1,5 @@
 from refiners.fluxion.adapters.adapter import Adapter
-from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
+from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter, auto_attach_loras

 __all__ = [
     "Adapter",
@@ -7,4 +7,5 @@ __all__ = [
     "LinearLora",
     "Conv2dLora",
     "LoraAdapter",
+    "auto_attach_loras",
 ]
diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py
index 1743bc3..802ee61 100644
--- a/src/refiners/fluxion/adapters/lora.py
+++ b/src/refiners/fluxion/adapters/lora.py
@@ -462,7 +462,7 @@ def auto_attach_loras(
         target: The target Chain.
         include: A list of layer names, only layers with such a layer in its parents will be considered.
         exclude: A list of layer names, layers with such a layer in its parents will not be considered.
-        debug_map: Pass a list to get a debug mapping of key - path pairs.
+        debug_map: Pass a list to get a debug mapping of key/path pairs for each attached point.

     Returns:
         A list of keys of LoRA layers which failed to attach.
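
Reviewer note: since `auto_attach_loras` is now part of the public `refiners.fluxion.adapters` API, here is a minimal sketch of how it can be driven with the `debug_map` parameter. This is not part of the diff; `Lora.from_dict` and the `state_dict`/`unet` variables are assumptions for illustration.

```python
from refiners.fluxion.adapters import Lora, auto_attach_loras

# Assumed setup: `state_dict` is a loaded LoRA checkpoint and `unet` is an fl.Chain.
loras = Lora.from_dict("my_lora", state_dict=state_dict)  # hypothetical name

# Collect (key, path) pairs for every LoRA that gets attached.
debug_map: list[tuple[str, str]] = []
failed_keys = auto_attach_loras(loras, unet, debug_map=debug_map)

assert not failed_keys, f"these LoRA keys did not attach: {failed_keys}"
for key, path in debug_map:
    print(f"{key} -> {path}")
```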
diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py
index 208a4fb..eab4071 100644
--- a/src/refiners/foundationals/latent_diffusion/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/__init__.py
@@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
     LatentDiffusionAutoencoder,
 )
 from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
-from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
+from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     SD1ControlnetAdapter,
     SD1IPAdapter,
@@ -18,6 +18,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
     ControlLoraAdapter,
     DoubleTextEncoder,
     SDXLIPAdapter,
+    SDXLLcmAdapter,
     SDXLT2IAdapter,
     SDXLUNet,
     StableDiffusion_XL,
@@ -34,8 +35,10 @@ __all__ = [
     "SDXLUNet",
     "DoubleTextEncoder",
     "SDXLIPAdapter",
+    "SDXLLcmAdapter",
     "SDXLT2IAdapter",
     "DPMSolver",
+    "LCMSolver",
     "Solver",
     "CLIPTextEncoderL",
     "LatentDiffusionAutoencoder",
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py
index d0189c0..00d6b50 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py
@@ -6,6 +6,16 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule


 class LCMSolver(Solver):
+    """Latent Consistency Model solver.
+
+    This solver is designed for use either with
+    [a specific base model][refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm.SDXLLcmAdapter]
+    or [a specific LoRA][refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora.add_lcm_lora].
+
+    See [[arXiv:2310.04378] Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378)
+    for details.
+    """
+
     def __init__(
         self,
         num_inference_steps: int,
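
Reviewer note: a quick sketch of how the newly exported `LCMSolver` is meant to be wired in, based on the usage in `tests/e2e/test_lcm.py` below; the device/dtype values are illustrative.

```python
import torch

from refiners.foundationals.latent_diffusion import LCMSolver, StableDiffusion_XL

# LCM converges in very few steps; 4 is a common choice.
solver = LCMSolver(num_inference_steps=4)
sdxl = StableDiffusion_XL(solver=solver, device="cuda", dtype=torch.float16)
```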
+ """ + def __init__( self, num_inference_steps: int, diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py index ce9d790..9031c6d 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py @@ -1,5 +1,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLora, ControlLoraAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import add_lcm_lora from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder @@ -11,7 +13,9 @@ __all__ = [ "DoubleTextEncoder", "SDXLAutoencoder", "SDXLIPAdapter", + "SDXLLcmAdapter", "SDXLT2IAdapter", "ControlLora", "ControlLoraAdapter", + "add_lcm_lora", ] diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py index 37137bf..abbd1d3 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py @@ -44,13 +44,25 @@ class ResidualBlock(fl.Residual): ) -class LcmAdapter(fl.Chain, Adapter[SDXLUNet]): +class SDXLLcmAdapter(fl.Chain, Adapter[SDXLUNet]): def __init__( self, target: SDXLUNet, condition_scale_embedding_dim: int = 256, condition_scale: float = 7.5, ) -> None: + """Adapt [the SDXl UNet][refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet.SDXLUNet] + for use with [LCMSolver][refiners.foundationals.latent_diffusion.solvers.lcm.LCMSolver]. + + Note that LCM must be used *without* CFG. You can disable CFG on SD by setting the + `classifier_free_guidance` attribute to `False`. + + Args: + target: A SDXL UNet. + condition_scale_embedding_dim: LCM uses a condition scale embedding, this is its dimension. + condition_scale: Because of the embedding, the condition scale must be passed to this adapter + instead of SD. The condition scale passed to SD will be ignored. 
+ """ assert condition_scale_embedding_dim % 2 == 0 self.condition_scale_embedding_dim = condition_scale_embedding_dim self.condition_scale = condition_scale @@ -71,7 +83,7 @@ class LcmAdapter(fl.Chain, Adapter[SDXLUNet]): self.condition_scale = scale self.set_context("lcm", {"condition_scale_embedding": self.sinusoidal_embedding}) - def inject(self: "LcmAdapter", parent: fl.Chain | None = None) -> "LcmAdapter": + def inject(self: "SDXLLcmAdapter", parent: fl.Chain | None = None) -> "SDXLLcmAdapter": ra = self.target.ensure_find(RangeEncoder) block = ResidualBlock( in_channels=self.condition_scale_embedding_dim, diff --git a/src/refiners/foundationals/latent_diffusion/lcm_lora.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py similarity index 75% rename from src/refiners/foundationals/latent_diffusion/lcm_lora.py rename to src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py index 912ce17..6d714fb 100644 --- a/src/refiners/foundationals/latent_diffusion/lcm_lora.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py @@ -1,9 +1,8 @@ import torch from refiners.fluxion.adapters.lora import Lora, auto_attach_loras - -from .lora import SDLoraManager -from .stable_diffusion_xl import StableDiffusion_XL +from refiners.foundationals.latent_diffusion.lora import SDLoraManager +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL def _check_validity(debug_map: list[tuple[str, str]]): @@ -25,14 +24,27 @@ def _check_validity(debug_map: list[tuple[str, str]]): def add_lcm_lora( manager: SDLoraManager, - name: str, tensors: dict[str, torch.Tensor], + name: str = "lcm", scale: float = 1.0 / 8.0, check_validity: bool = True, ) -> None: - # This is a complex LoRA so SDLoraManager.add_lora() is not enough. - # Instead, we add the LoRAs to the UNet in several iterations, using - # the filtering mechanism of `auto_attach_loras`. + """Add a LCM LoRA to SDXLUNet. + + This is a complex LoRA so [SDLoraManager.add_loras()][refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras] + is not enough. Instead, we add the LoRAs to the UNet in several iterations, using the filtering mechanism of + [auto_attach_loras][refiners.fluxion.adapters.lora.auto_attach_loras]. + + This LoRA can be used with or without CFG in SD. + If you use CFG, typical values range from 1.0 (same as no CFG) to 2.0. + + Args: + manager: A SDLoraManager for SDXL + tensors: The `state_dict` of the LCM LoRA + name: The name of the LoRA. + scale: The scale to use for the LoRA (should generally not be changed). + check_validity: Perform additional checks, raise an exception if they fail. 
+ """ assert isinstance(manager.target, StableDiffusion_XL) unet = manager.target.unet diff --git a/tests/e2e/test_lcm.py b/tests/e2e/test_lcm.py index df7aa4d..205b608 100644 --- a/tests/e2e/test_lcm.py +++ b/tests/e2e/test_lcm.py @@ -7,10 +7,10 @@ import torch from PIL import Image from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad -from refiners.foundationals.latent_diffusion.lcm_lora import add_lcm_lora from refiners.foundationals.latent_diffusion.lora import SDLoraManager from refiners.foundationals.latent_diffusion.solvers import LCMSolver -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import add_lcm_lora from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from tests.utils import ensure_similar_images @@ -105,7 +105,7 @@ def test_lcm_base( # With standard LCM the condition scale is passed to the adapter, # not in the diffusion loop. - LcmAdapter(sdxl.unet, condition_scale=8.0).inject() + SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject() sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights) sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights) @@ -158,7 +158,7 @@ def test_lcm_lora_with_guidance( sdxl.unet.load_from_safetensors(sdxl_unet_weights) manager = SDLoraManager(sdxl) - add_lcm_lora(manager, "lcm", load_from_safetensors(sdxl_lcm_lora_weights)) + add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights)) prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2 @@ -208,7 +208,7 @@ def test_lcm_lora_without_guidance( sdxl.unet.load_from_safetensors(sdxl_unet_weights) manager = SDLoraManager(sdxl) - add_lcm_lora(manager, "lcm", load_from_safetensors(sdxl_lcm_lora_weights)) + add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights)) prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" expected_image = expected_lcm_lora_1_0