mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add support for LCM LoRA weights loading
This commit is contained in:
parent
fafe5f8f5a
commit
12b6829a26
81
src/refiners/foundationals/latent_diffusion/lcm_lora.py
Normal file
81
src/refiners/foundationals/latent_diffusion/lcm_lora.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.fluxion.adapters.lora import Lora, auto_attach_loras
|
||||||
|
|
||||||
|
from .lora import SDLoraManager
|
||||||
|
from .stable_diffusion_xl import StableDiffusion_XL
|
||||||
|
|
||||||
|
|
||||||
|
def _check_validity(debug_map: list[tuple[str, str]]):
|
||||||
|
# Check things are in the right block.
|
||||||
|
prefix_map = {
|
||||||
|
"down_blocks_0": ["DownBlocks.Chain_1", "DownBlocks.Chain_2", "DownBlocks.Chain_3", "DownBlocks.Chain_4"],
|
||||||
|
"down_blocks_1": ["DownBlocks.Chain_5", "DownBlocks.Chain_6", "DownBlocks.Chain_7"],
|
||||||
|
"down_blocks_2": ["DownBlocks.Chain_8", "DownBlocks.Chain_9"],
|
||||||
|
"mid_block": ["MiddleBlock"],
|
||||||
|
"up_blocks_0": ["UpBlocks.Chain_1", "UpBlocks.Chain_2", "UpBlocks.Chain_3"],
|
||||||
|
"up_blocks_1": ["UpBlocks.Chain_4", "UpBlocks.Chain_5", "UpBlocks.Chain_6"],
|
||||||
|
"up_blocks_2": ["UpBlocks.Chain_7", "UpBlocks.Chain_8", "UpBlocks.Chain_9"],
|
||||||
|
}
|
||||||
|
for key, path in debug_map:
|
||||||
|
for key_pfx, paths_pfxs in prefix_map.items():
|
||||||
|
if key.startswith(f"lora_unet_{key_pfx}"):
|
||||||
|
assert any(path.startswith(f"SDXLUNet.{x}") for x in paths_pfxs), f"bad mapping: {key} {path}"
|
||||||
|
|
||||||
|
|
||||||
|
def add_lcm_lora(
|
||||||
|
manager: SDLoraManager,
|
||||||
|
name: str,
|
||||||
|
tensors: dict[str, torch.Tensor],
|
||||||
|
scale: float = 1.0 / 8.0,
|
||||||
|
check_validity: bool = True,
|
||||||
|
) -> None:
|
||||||
|
# This is a complex LoRA so SDLoraManager.add_lora() is not enough.
|
||||||
|
# Instead, we add the LoRAs to the UNet in several iterations, using
|
||||||
|
# the filtering mechanism of `auto_attach_loras`.
|
||||||
|
|
||||||
|
assert isinstance(manager.target, StableDiffusion_XL)
|
||||||
|
unet = manager.target.unet
|
||||||
|
|
||||||
|
loras = Lora.from_dict(name, {k: v.to(unet.device, unet.dtype) for k, v in tensors.items()})
|
||||||
|
assert all(k.startswith("lora_unet_") for k in loras.keys())
|
||||||
|
loras = {k: loras[k] for k in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
|
||||||
|
|
||||||
|
debug_map: list[tuple[str, str]] | None = [] if check_validity else None
|
||||||
|
|
||||||
|
# Projections are in `SDXLCrossAttention` but not in `CrossAttentionBlock`.
|
||||||
|
loras_projs = {k: v for k, v in loras.items() if k.endswith("proj_in") or k.endswith("proj_out")}
|
||||||
|
auto_attach_loras(
|
||||||
|
loras_projs,
|
||||||
|
unet,
|
||||||
|
exclude=["CrossAttentionBlock"],
|
||||||
|
include=["SDXLCrossAttention"],
|
||||||
|
debug_map=debug_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Do *not* check for time because some keys include both `resnets` and `time_emb_proj`.
|
||||||
|
exclusions = {
|
||||||
|
"res": "ResidualBlock",
|
||||||
|
"downsample": "Downsample",
|
||||||
|
"upsample": "Upsample",
|
||||||
|
}
|
||||||
|
loras_excluded = {k: v for k, v in loras.items() if any(x in k for x in exclusions.keys())}
|
||||||
|
loras_remaining = {k: v for k, v in loras.items() if k not in loras_excluded and k not in loras_projs}
|
||||||
|
|
||||||
|
auto_attach_loras(
|
||||||
|
loras_remaining,
|
||||||
|
unet,
|
||||||
|
exclude=[*exclusions.values(), "TimestepEncoder"],
|
||||||
|
debug_map=debug_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process exclusions one by one to avoid mixing them up.
|
||||||
|
for exc, v in exclusions.items():
|
||||||
|
ls = {k: v for k, v in loras_excluded.items() if exc in k}
|
||||||
|
auto_attach_loras(ls, unet, include=[v], debug_map=debug_map)
|
||||||
|
|
||||||
|
if debug_map is not None:
|
||||||
|
_check_validity(debug_map)
|
||||||
|
|
||||||
|
# LoRAs are finally injected, set the scale with the manager.
|
||||||
|
manager.set_scale(name, scale)
|
Loading…
Reference in a new issue