mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
add docstrings for LCM / LCM-LoRA
This commit is contained in:
parent
383c3c8a04
commit
684e2b9a47
|
@ -8,7 +8,7 @@ from torch import nn
|
||||||
|
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
||||||
|
|
||||||
|
|
||||||
class Args(argparse.Namespace):
|
class Args(argparse.Namespace):
|
||||||
|
@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
||||||
|
|
||||||
if source_is_lcm:
|
if source_is_lcm:
|
||||||
assert isinstance(target, SDXLUNet)
|
assert isinstance(target, SDXLUNet)
|
||||||
LcmAdapter(target=target).inject()
|
SDXLLcmAdapter(target=target).inject()
|
||||||
|
|
||||||
x = torch.randn(1, source_in_channels, 32, 32)
|
x = torch.randn(1, source_in_channels, 32, 32)
|
||||||
timestep = torch.tensor(data=[0])
|
timestep = torch.tensor(data=[0])
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter
|
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter, auto_attach_loras
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Adapter",
|
"Adapter",
|
||||||
|
@ -7,4 +7,5 @@ __all__ = [
|
||||||
"LinearLora",
|
"LinearLora",
|
||||||
"Conv2dLora",
|
"Conv2dLora",
|
||||||
"LoraAdapter",
|
"LoraAdapter",
|
||||||
|
"auto_attach_loras",
|
||||||
]
|
]
|
||||||
|
|
|
@ -462,7 +462,7 @@ def auto_attach_loras(
|
||||||
target: The target Chain.
|
target: The target Chain.
|
||||||
include: A list of layer names, only layers with such a layer in its parents will be considered.
|
include: A list of layer names, only layers with such a layer in its parents will be considered.
|
||||||
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
|
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
|
||||||
debug_map: Pass a list to get a debug mapping of key - path pairs.
|
debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of keys of LoRA layers which failed to attach.
|
A list of keys of LoRA layers which failed to attach.
|
||||||
|
|
|
@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
|
||||||
LatentDiffusionAutoencoder,
|
LatentDiffusionAutoencoder,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
|
from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
|
||||||
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
|
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
||||||
SD1ControlnetAdapter,
|
SD1ControlnetAdapter,
|
||||||
SD1IPAdapter,
|
SD1IPAdapter,
|
||||||
|
@ -18,6 +18,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
||||||
ControlLoraAdapter,
|
ControlLoraAdapter,
|
||||||
DoubleTextEncoder,
|
DoubleTextEncoder,
|
||||||
SDXLIPAdapter,
|
SDXLIPAdapter,
|
||||||
|
SDXLLcmAdapter,
|
||||||
SDXLT2IAdapter,
|
SDXLT2IAdapter,
|
||||||
SDXLUNet,
|
SDXLUNet,
|
||||||
StableDiffusion_XL,
|
StableDiffusion_XL,
|
||||||
|
@ -34,8 +35,10 @@ __all__ = [
|
||||||
"SDXLUNet",
|
"SDXLUNet",
|
||||||
"DoubleTextEncoder",
|
"DoubleTextEncoder",
|
||||||
"SDXLIPAdapter",
|
"SDXLIPAdapter",
|
||||||
|
"SDXLLcmAdapter",
|
||||||
"SDXLT2IAdapter",
|
"SDXLT2IAdapter",
|
||||||
"DPMSolver",
|
"DPMSolver",
|
||||||
|
"LCMSolver",
|
||||||
"Solver",
|
"Solver",
|
||||||
"CLIPTextEncoderL",
|
"CLIPTextEncoderL",
|
||||||
"LatentDiffusionAutoencoder",
|
"LatentDiffusionAutoencoder",
|
||||||
|
|
|
@ -6,6 +6,16 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
||||||
|
|
||||||
|
|
||||||
class LCMSolver(Solver):
|
class LCMSolver(Solver):
|
||||||
|
"""Latent Consistency Model solver.
|
||||||
|
|
||||||
|
This solver is designed for use either with
|
||||||
|
[a specific base model][refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm.SDXLLcmAdapter]
|
||||||
|
or [a specific LoRA][refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora.add_lcm_lora].
|
||||||
|
|
||||||
|
See [[arXiv:2310.04378] Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378)
|
||||||
|
for details.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLora, ControlLoraAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLora, ControlLoraAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import add_lcm_lora
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import SDXLAutoencoder, StableDiffusion_XL
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
|
@ -11,7 +13,9 @@ __all__ = [
|
||||||
"DoubleTextEncoder",
|
"DoubleTextEncoder",
|
||||||
"SDXLAutoencoder",
|
"SDXLAutoencoder",
|
||||||
"SDXLIPAdapter",
|
"SDXLIPAdapter",
|
||||||
|
"SDXLLcmAdapter",
|
||||||
"SDXLT2IAdapter",
|
"SDXLT2IAdapter",
|
||||||
"ControlLora",
|
"ControlLora",
|
||||||
"ControlLoraAdapter",
|
"ControlLoraAdapter",
|
||||||
|
"add_lcm_lora",
|
||||||
]
|
]
|
||||||
|
|
|
@ -44,13 +44,25 @@ class ResidualBlock(fl.Residual):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
|
class SDXLLcmAdapter(fl.Chain, Adapter[SDXLUNet]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: SDXLUNet,
|
target: SDXLUNet,
|
||||||
condition_scale_embedding_dim: int = 256,
|
condition_scale_embedding_dim: int = 256,
|
||||||
condition_scale: float = 7.5,
|
condition_scale: float = 7.5,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Adapt [the SDXl UNet][refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet.SDXLUNet]
|
||||||
|
for use with [LCMSolver][refiners.foundationals.latent_diffusion.solvers.lcm.LCMSolver].
|
||||||
|
|
||||||
|
Note that LCM must be used *without* CFG. You can disable CFG on SD by setting the
|
||||||
|
`classifier_free_guidance` attribute to `False`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: A SDXL UNet.
|
||||||
|
condition_scale_embedding_dim: LCM uses a condition scale embedding, this is its dimension.
|
||||||
|
condition_scale: Because of the embedding, the condition scale must be passed to this adapter
|
||||||
|
instead of SD. The condition scale passed to SD will be ignored.
|
||||||
|
"""
|
||||||
assert condition_scale_embedding_dim % 2 == 0
|
assert condition_scale_embedding_dim % 2 == 0
|
||||||
self.condition_scale_embedding_dim = condition_scale_embedding_dim
|
self.condition_scale_embedding_dim = condition_scale_embedding_dim
|
||||||
self.condition_scale = condition_scale
|
self.condition_scale = condition_scale
|
||||||
|
@ -71,7 +83,7 @@ class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
|
||||||
self.condition_scale = scale
|
self.condition_scale = scale
|
||||||
self.set_context("lcm", {"condition_scale_embedding": self.sinusoidal_embedding})
|
self.set_context("lcm", {"condition_scale_embedding": self.sinusoidal_embedding})
|
||||||
|
|
||||||
def inject(self: "LcmAdapter", parent: fl.Chain | None = None) -> "LcmAdapter":
|
def inject(self: "SDXLLcmAdapter", parent: fl.Chain | None = None) -> "SDXLLcmAdapter":
|
||||||
ra = self.target.ensure_find(RangeEncoder)
|
ra = self.target.ensure_find(RangeEncoder)
|
||||||
block = ResidualBlock(
|
block = ResidualBlock(
|
||||||
in_channels=self.condition_scale_embedding_dim,
|
in_channels=self.condition_scale_embedding_dim,
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.fluxion.adapters.lora import Lora, auto_attach_loras
|
from refiners.fluxion.adapters.lora import Lora, auto_attach_loras
|
||||||
|
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||||
from .lora import SDLoraManager
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
from .stable_diffusion_xl import StableDiffusion_XL
|
|
||||||
|
|
||||||
|
|
||||||
def _check_validity(debug_map: list[tuple[str, str]]):
|
def _check_validity(debug_map: list[tuple[str, str]]):
|
||||||
|
@ -25,14 +24,27 @@ def _check_validity(debug_map: list[tuple[str, str]]):
|
||||||
|
|
||||||
def add_lcm_lora(
|
def add_lcm_lora(
|
||||||
manager: SDLoraManager,
|
manager: SDLoraManager,
|
||||||
name: str,
|
|
||||||
tensors: dict[str, torch.Tensor],
|
tensors: dict[str, torch.Tensor],
|
||||||
|
name: str = "lcm",
|
||||||
scale: float = 1.0 / 8.0,
|
scale: float = 1.0 / 8.0,
|
||||||
check_validity: bool = True,
|
check_validity: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
# This is a complex LoRA so SDLoraManager.add_lora() is not enough.
|
"""Add a LCM LoRA to SDXLUNet.
|
||||||
# Instead, we add the LoRAs to the UNet in several iterations, using
|
|
||||||
# the filtering mechanism of `auto_attach_loras`.
|
This is a complex LoRA so [SDLoraManager.add_loras()][refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras]
|
||||||
|
is not enough. Instead, we add the LoRAs to the UNet in several iterations, using the filtering mechanism of
|
||||||
|
[auto_attach_loras][refiners.fluxion.adapters.lora.auto_attach_loras].
|
||||||
|
|
||||||
|
This LoRA can be used with or without CFG in SD.
|
||||||
|
If you use CFG, typical values range from 1.0 (same as no CFG) to 2.0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
manager: A SDLoraManager for SDXL
|
||||||
|
tensors: The `state_dict` of the LCM LoRA
|
||||||
|
name: The name of the LoRA.
|
||||||
|
scale: The scale to use for the LoRA (should generally not be changed).
|
||||||
|
check_validity: Perform additional checks, raise an exception if they fail.
|
||||||
|
"""
|
||||||
|
|
||||||
assert isinstance(manager.target, StableDiffusion_XL)
|
assert isinstance(manager.target, StableDiffusion_XL)
|
||||||
unet = manager.target.unet
|
unet = manager.target.unet
|
|
@ -7,10 +7,10 @@ import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
||||||
from refiners.foundationals.latent_diffusion.lcm_lora import add_lcm_lora
|
|
||||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||||
from refiners.foundationals.latent_diffusion.solvers import LCMSolver
|
from refiners.foundationals.latent_diffusion.solvers import LCMSolver
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import SDXLLcmAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm_lora import add_lcm_lora
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
@ -105,7 +105,7 @@ def test_lcm_base(
|
||||||
|
|
||||||
# With standard LCM the condition scale is passed to the adapter,
|
# With standard LCM the condition scale is passed to the adapter,
|
||||||
# not in the diffusion loop.
|
# not in the diffusion loop.
|
||||||
LcmAdapter(sdxl.unet, condition_scale=8.0).inject()
|
SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject()
|
||||||
|
|
||||||
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
|
||||||
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
|
||||||
|
@ -158,7 +158,7 @@ def test_lcm_lora_with_guidance(
|
||||||
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl)
|
manager = SDLoraManager(sdxl)
|
||||||
add_lcm_lora(manager, "lcm", load_from_safetensors(sdxl_lcm_lora_weights))
|
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
|
expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
|
||||||
|
@ -208,7 +208,7 @@ def test_lcm_lora_without_guidance(
|
||||||
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
sdxl.unet.load_from_safetensors(sdxl_unet_weights)
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl)
|
manager = SDLoraManager(sdxl)
|
||||||
add_lcm_lora(manager, "lcm", load_from_safetensors(sdxl_lcm_lora_weights))
|
add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
|
||||||
|
|
||||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||||
expected_image = expected_lcm_lora_1_0
|
expected_image = expected_lcm_lora_1_0
|
||||||
|
|
Loading…
Reference in a new issue