add docstrings for LCM / LCM-LoRA

This commit is contained in:
Pierre Chapuis 2024-02-20 16:22:08 +01:00
parent 383c3c8a04
commit 684e2b9a47
9 changed files with 61 additions and 19 deletions

View file

@ -8,7 +8,7 @@ from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
class Args(argparse.Namespace):
@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter:
if source_is_lcm:
assert isinstance(target, SDXLUNet)
LcmAdapter(target=target).inject()
SDXLLcmAdapter(target=target).inject()
x = torch.randn(1, source_in_channels, 32, 32)
timestep = torch.tensor(data=[0])

View file

@ -1,5 +1,5 @@
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter, auto_attach_loras
__all__ = [
"Adapter",
@ -7,4 +7,5 @@ __all__ = [
"LinearLora",
"Conv2dLora",
"LoraAdapter",
"auto_attach_loras",
]

View file

@ -462,7 +462,7 @@ def auto_attach_loras(
target: The target Chain.
include: A list of layer names, only layers with such a layer in its parents will be considered.
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
debug_map: Pass a list to get a debug mapping of key - path pairs.
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
Returns:
A list of keys of LoRA layers which failed to attach.

View file

@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
LatentDiffusionAutoencoder,
)
from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
SD1ControlnetAdapter,
SD1IPAdapter,
@ -18,6 +18,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
ControlLoraAdapter,
DoubleTextEncoder,
SDXLIPAdapter,
SDXLLcmAdapter,
SDXLT2IAdapter,
SDXLUNet,
StableDiffusion_XL,
@ -34,8 +35,10 @@ __all__ = [
"SDXLUNet",
"DoubleTextEncoder",
"SDXLIPAdapter",
"SDXLLcmAdapter",
"SDXLT2IAdapter",
"DPMSolver",
"LCMSolver",
"Solver",
"CLIPTextEncoderL",
"LatentDiffusionAutoencoder",

View file

@ -6,6 +6,16 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
class LCMSolver(Solver):
"""Latent Consistency Model solver.
This solver is designed for use either with
[a specific base model][refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm.SDXLLcmAdapter]
or [a specific LoRA][refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora.add_lcm_lora].
See [[arXiv:2310.04378] Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378)
for details.
"""
def __init__(
self,
num_inference_steps: int,

View file

@ -1,5 +1,7 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLora, ControlLoraAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import add_lcm_lora
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
@ -11,7 +13,9 @@ __all__ = [
"DoubleTextEncoder",
"SDXLAutoencoder",
"SDXLIPAdapter",
"SDXLLcmAdapter",
"SDXLT2IAdapter",
"ControlLora",
"ControlLoraAdapter",
"add_lcm_lora",
]

View file

@ -44,13 +44,25 @@ class ResidualBlock(fl.Residual):
)
class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
class SDXLLcmAdapter(fl.Chain, Adapter[SDXLUNet]):
def __init__(
self,
target: SDXLUNet,
condition_scale_embedding_dim: int = 256,
condition_scale: float = 7.5,
) -> None:
"""Adapt [the SDXl UNet][refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet.SDXLUNet]
for use with [LCMSolver][refiners.foundationals.latent_diffusion.solvers.lcm.LCMSolver].
Note that LCM must be used *without* CFG. You can disable CFG on SD by setting the
`classifier_free_guidance` attribute to `False`.
Args:
target: A SDXL UNet.
condition_scale_embedding_dim: LCM uses a condition scale embedding, this is its dimension.
condition_scale: Because of the embedding, the condition scale must be passed to this adapter
instead of SD. The condition scale passed to SD will be ignored.
"""
assert condition_scale_embedding_dim % 2 == 0
self.condition_scale_embedding_dim = condition_scale_embedding_dim
self.condition_scale = condition_scale
@ -71,7 +83,7 @@ class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
self.condition_scale = scale
self.set_context("lcm", {"condition_scale_embedding": self.sinusoidal_embedding})
def inject(self: "LcmAdapter", parent: fl.Chain | None = None) -> "LcmAdapter":
def inject(self: "SDXLLcmAdapter", parent: fl.Chain | None = None) -> "SDXLLcmAdapter":
ra = self.target.ensure_find(RangeEncoder)
block = ResidualBlock(
in_channels=self.condition_scale_embedding_dim,

View file

@ -1,9 +1,8 @@
import torch
from refiners.fluxion.adapters.lora import Lora, auto_attach_loras
from .lora import SDLoraManager
from .stable_diffusion_xl import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
def _check_validity(debug_map: list[tuple[str, str]]):
@ -25,14 +24,27 @@ def _check_validity(debug_map: list[tuple[str, str]]):
def add_lcm_lora(
manager: SDLoraManager,
name: str,
tensors: dict[str, torch.Tensor],
name: str = "lcm",
scale: float = 1.0 / 8.0,
check_validity: bool = True,
) -> None:
# This is a complex LoRA so SDLoraManager.add_lora() is not enough.
# Instead, we add the LoRAs to the UNet in several iterations, using
# the filtering mechanism of `auto_attach_loras`.
"""Add a LCM LoRA to SDXLUNet.
This is a complex LoRA so [SDLoraManager.add_loras()][refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras]
is not enough. Instead, we add the LoRAs to the UNet in several iterations, using the filtering mechanism of
[auto_attach_loras][refiners.fluxion.adapters.lora.auto_attach_loras].
This LoRA can be used with or without CFG in SD.
If you use CFG, typical values range from 1.0 (same as no CFG) to 2.0.
Args:
manager: A SDLoraManager for SDXL
tensors: The `state_dict` of the LCM LoRA
name: The name of the LoRA.
scale: The scale to use for the LoRA (should generally not be changed).
check_validity: Perform additional checks, raise an exception if they fail.
"""
assert isinstance(manager.target, StableDiffusion_XL)
unet = manager.target.unet

View file

@ -7,10 +7,10 @@ import torch
from PIL import Image
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.latent_diffusion.lcm_lora import add_lcm_lora
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.solvers import LCMSolver
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import add_lcm_lora
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from tests.utils import ensure_similar_images
@ -105,7 +105,7 @@ def test_lcm_base(
# With standard LCM the condition scale is passed to the adapter,
# not in the diffusion loop.
LcmAdapter(sdxl.unet, condition_scale=8.0).inject()
SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject()
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
@ -158,7 +158,7 @@ def test_lcm_lora_with_guidance(
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
manager = SDLoraManager(sdxl)
add_lcm_lora(manager, "lcm", load_from_safetensors(sdxl_lcm_lora_weights))
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
@ -208,7 +208,7 @@ def test_lcm_lora_without_guidance(
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
manager = SDLoraManager(sdxl)
add_lcm_lora(manager, "lcm", load_from_safetensors(sdxl_lcm_lora_weights))
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
expected_image = expected_lcm_lora_1_0