mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
add docstrings for LCM / LCM-LoRA
This commit is contained in:
parent
383c3c8a04
commit
684e2b9a47
|
@ -8,7 +8,7 @@ from torch import nn
|
|||
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
||||
|
||||
|
||||
class Args(argparse.Namespace):
|
||||
|
@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
|
||||
if source_is_lcm:
|
||||
assert isinstance(target, SDXLUNet)
|
||||
LcmAdapter(target=target).inject()
|
||||
SDXLLcmAdapter(target=target).inject()
|
||||
|
||||
x = torch.randn(1, source_in_channels, 32, 32)
|
||||
timestep = torch.tensor(data=[0])
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
|
||||
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter, auto_attach_loras
|
||||
|
||||
__all__ = [
|
||||
"Adapter",
|
||||
|
@ -7,4 +7,5 @@ __all__ = [
|
|||
"LinearLora",
|
||||
"Conv2dLora",
|
||||
"LoraAdapter",
|
||||
"auto_attach_loras",
|
||||
]
|
||||
|
|
|
@ -462,7 +462,7 @@ def auto_attach_loras(
|
|||
target: The target Chain.
|
||||
include: A list of layer names; only layers with such a layer among their parents will be considered.
|
||||
exclude: A list of layer names; layers with such a layer among their parents will not be considered.
|
||||
debug_map: Pass a list to get a debug mapping of key - path pairs.
|
||||
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
|
||||
|
||||
Returns:
|
||||
A list of keys of LoRA layers which failed to attach.
|
||||
|
|
|
@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
|
|||
LatentDiffusionAutoencoder,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
|
||||
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
|
||||
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
||||
SD1ControlnetAdapter,
|
||||
SD1IPAdapter,
|
||||
|
@ -18,6 +18,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
|||
ControlLoraAdapter,
|
||||
DoubleTextEncoder,
|
||||
SDXLIPAdapter,
|
||||
SDXLLcmAdapter,
|
||||
SDXLT2IAdapter,
|
||||
SDXLUNet,
|
||||
StableDiffusion_XL,
|
||||
|
@ -34,8 +35,10 @@ __all__ = [
|
|||
"SDXLUNet",
|
||||
"DoubleTextEncoder",
|
||||
"SDXLIPAdapter",
|
||||
"SDXLLcmAdapter",
|
||||
"SDXLT2IAdapter",
|
||||
"DPMSolver",
|
||||
"LCMSolver",
|
||||
"Solver",
|
||||
"CLIPTextEncoderL",
|
||||
"LatentDiffusionAutoencoder",
|
||||
|
|
|
@ -6,6 +6,16 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
|||
|
||||
|
||||
class LCMSolver(Solver):
|
||||
"""Latent Consistency Model solver.
|
||||
|
||||
This solver is designed for use either with
|
||||
[a specific base model][refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm.SDXLLcmAdapter]
|
||||
or [a specific LoRA][refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora.add_lcm_lora].
|
||||
|
||||
See [[arXiv:2310.04378] Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378)
|
||||
for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLora, ControlLoraAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import add_lcm_lora
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
|
@ -11,7 +13,9 @@ __all__ = [
|
|||
"DoubleTextEncoder",
|
||||
"SDXLAutoencoder",
|
||||
"SDXLIPAdapter",
|
||||
"SDXLLcmAdapter",
|
||||
"SDXLT2IAdapter",
|
||||
"ControlLora",
|
||||
"ControlLoraAdapter",
|
||||
"add_lcm_lora",
|
||||
]
|
||||
|
|
|
@ -44,13 +44,25 @@ class ResidualBlock(fl.Residual):
|
|||
)
|
||||
|
||||
|
||||
class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
|
||||
class SDXLLcmAdapter(fl.Chain, Adapter[SDXLUNet]):
|
||||
def __init__(
|
||||
self,
|
||||
target: SDXLUNet,
|
||||
condition_scale_embedding_dim: int = 256,
|
||||
condition_scale: float = 7.5,
|
||||
) -> None:
|
||||
"""Adapt [the SDXL UNet][refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet.SDXLUNet]
|
||||
for use with [LCMSolver][refiners.foundationals.latent_diffusion.solvers.lcm.LCMSolver].
|
||||
|
||||
Note that LCM must be used *without* CFG. You can disable CFG on SD by setting the
|
||||
`classifier_free_guidance` attribute to `False`.
|
||||
|
||||
Args:
|
||||
target: An SDXL UNet.
|
||||
condition_scale_embedding_dim: LCM uses a condition scale embedding, this is its dimension.
|
||||
condition_scale: Because of the embedding, the condition scale must be passed to this adapter
|
||||
instead of SD. The condition scale passed to SD will be ignored.
|
||||
"""
|
||||
assert condition_scale_embedding_dim % 2 == 0
|
||||
self.condition_scale_embedding_dim = condition_scale_embedding_dim
|
||||
self.condition_scale = condition_scale
|
||||
|
@ -71,7 +83,7 @@ class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
|
|||
self.condition_scale = scale
|
||||
self.set_context("lcm", {"condition_scale_embedding": self.sinusoidal_embedding})
|
||||
|
||||
def inject(self: "LcmAdapter", parent: fl.Chain | None = None) -> "LcmAdapter":
|
||||
def inject(self: "SDXLLcmAdapter", parent: fl.Chain | None = None) -> "SDXLLcmAdapter":
|
||||
ra = self.target.ensure_find(RangeEncoder)
|
||||
block = ResidualBlock(
|
||||
in_channels=self.condition_scale_embedding_dim,
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import torch
|
||||
|
||||
from refiners.fluxion.adapters.lora import Lora, auto_attach_loras
|
||||
|
||||
from .lora import SDLoraManager
|
||||
from .stable_diffusion_xl import StableDiffusion_XL
|
||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
|
||||
|
||||
def _check_validity(debug_map: list[tuple[str, str]]):
|
||||
|
@ -25,14 +24,27 @@ def _check_validity(debug_map: list[tuple[str, str]]):
|
|||
|
||||
def add_lcm_lora(
|
||||
manager: SDLoraManager,
|
||||
name: str,
|
||||
tensors: dict[str, torch.Tensor],
|
||||
name: str = "lcm",
|
||||
scale: float = 1.0 / 8.0,
|
||||
check_validity: bool = True,
|
||||
) -> None:
|
||||
# This is a complex LoRA so SDLoraManager.add_lora() is not enough.
|
||||
# Instead, we add the LoRAs to the UNet in several iterations, using
|
||||
# the filtering mechanism of `auto_attach_loras`.
|
||||
"""Add an LCM LoRA to the SDXL UNet.
|
||||
|
||||
This is a complex LoRA so [SDLoraManager.add_loras()][refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras]
|
||||
is not enough. Instead, we add the LoRAs to the UNet in several iterations, using the filtering mechanism of
|
||||
[auto_attach_loras][refiners.fluxion.adapters.lora.auto_attach_loras].
|
||||
|
||||
This LoRA can be used with or without CFG in SD.
|
||||
If you use CFG, typical condition scale values range from 1.0 (same as no CFG) to 2.0.
|
||||
|
||||
Args:
|
||||
manager: An SDLoraManager for SDXL.
|
||||
tensors: The `state_dict` of the LCM LoRA.
|
||||
name: The name of the LoRA.
|
||||
scale: The scale to use for the LoRA (should generally not be changed).
|
||||
check_validity: Perform additional checks, raise an exception if they fail.
|
||||
"""
|
||||
|
||||
assert isinstance(manager.target, StableDiffusion_XL)
|
||||
unet = manager.target.unet
|
|
@ -7,10 +7,10 @@ import torch
|
|||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
||||
from refiners.foundationals.latent_diffusion.lcm_lora import add_lcm_lora
|
||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||
from refiners.foundationals.latent_diffusion.solvers import LCMSolver
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import add_lcm_lora
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
@ -105,7 +105,7 @@ def test_lcm_base(
|
|||
|
||||
# With standard LCM the condition scale is passed to the adapter,
|
||||
# not in the diffusion loop.
|
||||
LcmAdapter(sdxl.unet, condition_scale=8.0).inject()
|
||||
SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject()
|
||||
|
||||
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
||||
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
||||
|
@ -158,7 +158,7 @@ def test_lcm_lora_with_guidance(
|
|||
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
||||
|
||||
manager = SDLoraManager(sdxl)
|
||||
add_lcm_lora(manager, "lcm", load_from_safetensors(sdxl_lcm_lora_weights))
|
||||
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
|
||||
|
||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
|
||||
|
@ -208,7 +208,7 @@ def test_lcm_lora_without_guidance(
|
|||
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
||||
|
||||
manager = SDLoraManager(sdxl)
|
||||
add_lcm_lora(manager, "lcm", load_from_safetensors(sdxl_lcm_lora_weights))
|
||||
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
|
||||
|
||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||
expected_image = expected_lcm_lora_1_0
|
||||
|
|
Loading…
Reference in a new issue