refactor Lora, LoraAdapter and the latent_diffusion/lora file

limiteinductive 2024-01-18 15:34:29 +01:00 committed by Benjamin Trom
parent dd87b9706e
commit a1f50f3f9d
12 changed files with 430 additions and 874 deletions


@@ -158,16 +158,6 @@ You should get:
![dropy slime output](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_slime_9752.png)
### Training
Refiners has a built-in training utils library and provides scripts that can be used as a starting point.
E.g. to train a LoRA on top of Stable Diffusion, copy and edit `configs/finetune-lora.toml` to suit your needs and launch the training as follows:
```bash
python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml
```
## Adapter Zoo
For now, given [finegrain](https://finegrain.ai)'s mission, we are focusing on image editing tasks. We support:


@@ -1,68 +0,0 @@
[wandb]
mode = "offline" # "online", "offline", "disabled"
entity = "acme"
project = "test-lora-training"
[models]
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
[latent_diffusion]
unconditional_sampling_probability = 0.05
offset_noise = 0.1
[lora]
rank = 16
trigger_phrase = "a spsh photo,"
use_only_trigger_probability = 1.0
unet_targets = ["CrossAttentionBlock2d"]
text_encoder_targets = ["TransformerLayer"]
lda_targets = []
[training]
duration = "1000:epoch"
seed = 0
gpu_index = 0
batch_size = 4
gradient_accumulation = "4:step"
clip_grad_norm = 1.0
# clip_grad_value = 1.0
evaluation_interval = "5:epoch"
evaluation_seed = 1
[optimizer]
optimizer = "Prodigy" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
learning_rate = 1
betas = [0.9, 0.999]
eps = 1e-8
weight_decay = 1e-2
[scheduler]
scheduler_type = "ConstantLR"
update_interval = "1:step"
warmup = "500:step"
[dropout]
dropout_probability = 0.2
use_gyro_dropout = false
[dataset]
hf_repo = "acme/images"
revision = "main"
[checkpointing]
# save_folder = "/path/to/ckpts"
save_interval = "1:step"
[test_diffusion]
num_inference_steps = 30
use_short_prompts = false
prompts = [
"a cute cat",
"a cute dog",
"a cute bird",
"a cute horse",
]


@@ -1,149 +0,0 @@
import argparse
from pathlib import Path
from typing import cast
import torch
from diffusers import DiffusionPipeline # type: ignore
from torch import Tensor
from torch.nn import Parameter as TorchParameter
from torch.nn.init import zeros_
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
def get_weight(linear: fl.Linear) -> torch.Tensor:
assert linear.bias is None
return linear.state_dict()["weight"]
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
weights: list[torch.Tensor] = []
for lora in module.layers(layer_type=Lora):
linears = list(lora.layers(layer_type=fl.Linear))
assert len(linears) == 2
weights.extend((get_weight(linear=linears[1]), get_weight(linear=linears[0]))) # aka (up_weight, down_weight)
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
class Args(argparse.Namespace):
source_path: str
base_model: str
output_path: str
verbose: bool
@no_grad()
def process(args: Args) -> None:
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
# low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
diffusers_sd = DiffusionPipeline.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.base_model,
low_cpu_mem_usage=False,
)
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
refiners_model = SD1UNet(in_channels=4)
target = LoraTarget.CrossAttention
metadata = {"unet_targets": "CrossAttentionBlock2d"}
rank = diffusers_state_dict[
"mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
].shape[0]
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, 768)
refiners_model.set_timestep(timestep=timestep)
refiners_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
refiners_args = (x,)
diffusers_args = (x, timestep, clip_text_embeddings)
converter = ModelConverter(
source_model=refiners_model, target_model=diffusers_model, skip_output_check=True, verbose=args.verbose
)
if not converter.run(
source_args=refiners_args,
target_args=diffusers_args,
):
raise RuntimeError("Model conversion failed")
diffusers_to_refiners = converter.get_mapping()
LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
for layer in refiners_model.layers(layer_type=Lora):
zeros_(tensor=layer.Linear_1.weight)
targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
for target_k in targets:
k_p, k_s = target_k.split(".processor.")
r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
assert len(r) == 1
orig_k = r[0]
orig_path = orig_k.split(sep=".")
p = refiners_model
for seg in orig_path[:-1]:
p = p[seg]
assert isinstance(p, fl.Chain)
last_seg = (
"SingleLoraAdapter"
if orig_path[-1] == "Linear"
else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
)
p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
p[last_seg].Lora.load_weights(p_down, p_up)
state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
assert len(state_dict) == 320
save_to_safetensors(path=args.output_path, tensors=state_dict, metadata=metadata)
def main() -> None:
parser = argparse.ArgumentParser(description="Convert LoRAs saved using the diffusers library to refiners format.")
parser.add_argument(
"--from",
type=str,
dest="source_path",
required=True,
help="Source file path (.bin|safetensors) containing the LoRAs.",
)
parser.add_argument(
"--base-model",
type=str,
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Base model, used for the UNet structure. Default: runwayml/stable-diffusion-v1-5",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
required=False,
default=None,
help=(
"Output file path (.safetensors) for converted LoRAs. If not provided, the output path will be the same as"
" the source path."
),
)
parser.add_argument(
"--verbose",
action="store_true",
dest="verbose",
default=False,
help="Use this flag to print verbose output during conversion.",
)
args = parser.parse_args(namespace=Args())
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
process(args=args)
if __name__ == "__main__":
main()


@@ -1,124 +0,0 @@
import argparse
from functools import partial
from convert_diffusers_unet import Args as UnetConversionArgs, setup_converter as convert_unet
from convert_transformers_clip_text_model import (
Args as TextEncoderConversionArgs,
setup_converter as convert_text_encoder,
)
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import (
load_from_safetensors,
load_metadata_from_safetensors,
save_to_safetensors,
)
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget
def get_unet_mapping(source_path: str) -> dict[str, str]:
args = UnetConversionArgs(source_path=source_path, verbose=False)
return convert_unet(args=args).get_mapping()
def get_text_encoder_mapping(source_path: str) -> dict[str, str]:
args = TextEncoderConversionArgs(source_path=source_path, subfolder="text_encoder", verbose=False)
return convert_text_encoder(
args=args,
).get_mapping()
def main() -> None:
parser = argparse.ArgumentParser(description="Converts a refiner's LoRA weights to SD-WebUI's LoRA weights")
parser.add_argument(
"-i",
"--input-file",
type=str,
required=True,
help="Path to the input file with refiner's LoRA weights (safetensors format)",
)
parser.add_argument(
"-o",
"--output-file",
type=str,
default="sdwebui_loras.safetensors",
help="Path to the output file with sd-webui's LoRA weights (safetensors format)",
)
parser.add_argument(
"--sd15",
type=str,
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)",
)
args = parser.parse_args()
metadata = load_metadata_from_safetensors(path=args.input_file)
assert metadata is not None, f"Could not load metadata from {args.input_file}"
tensors = load_from_safetensors(path=args.input_file)
state_dict: dict[str, Tensor] = {}
for meta_key, meta_value in metadata.items():
match meta_key:
case "unet_targets":
model = SD1UNet(in_channels=4)
create_mapping = partial(get_unet_mapping, source_path=args.sd15)
key_prefix = "unet."
lora_prefix = "lora_unet_"
case "text_encoder_targets":
model = CLIPTextEncoderL()
create_mapping = partial(get_text_encoder_mapping, source_path=args.sd15)
key_prefix = "text_encoder."
lora_prefix = "lora_te_"
case "lda_targets":
raise ValueError("SD-WebUI does not support LoRA for the auto-encoder")
case _:
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
submodule_to_key: dict[fl.Module, str] = {}
for name, submodule in model.named_modules():
submodule_to_key[submodule] = name
# SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.:
#
# lora_unet_down_blocks_0_attentions_0_proj_in.alpha
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight
# ...
#
# Internally SD-WebUI has some logic[1] to convert such keys into the CompVis format. See
# `convert_diffusers_name_to_compvis` for more details.
#
# [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225
refiners_to_diffusers = create_mapping()
assert refiners_to_diffusers is not None
# Compute the corresponding diffusers' keys where LoRA layers must be applied
lora_injection_points: list[str] = [
refiners_to_diffusers[submodule_to_key[linear]]
for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
for layer in model.layers(layer_type=target.get_class())
for linear in layer.layers(layer_type=fl.Linear)
]
lora_weights = [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]
assert len(lora_injection_points) == len(lora_weights) // 2
# Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores)
for i, diffusers_key in enumerate(iterable=lora_injection_points):
lora_key = lora_prefix + diffusers_key.replace(".", "_")
# Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). Refiners uses a scale = 1.0
# by default (see `lora_calc_updown` in SD-WebUI for more details)
state_dict[lora_key + ".lora_up.weight"] = lora_weights[2 * i]
state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
assert state_dict
save_to_safetensors(path=args.output_file, tensors=state_dict)
if __name__ == "__main__":
main()


@@ -229,10 +229,15 @@ def download_vae_ft_mse():
)
def download_lora():
dest_folder = os.path.join(test_weights_dir, "pcuenq", "pokemon-lora")
def download_loras():
dest_folder = os.path.join(test_weights_dir, "loras", "pokemon-lora")
download_file("https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin", dest_folder)
dest_folder = os.path.join(test_weights_dir, "loras", "dpo-lora")
download_file(
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
)
def download_preprocessors():
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
@@ -454,17 +459,6 @@ def convert_vae_fp16_fix():
)
def convert_lora():
os.makedirs("tests/weights/loras", exist_ok=True)
run_conversion_script(
"convert_diffusers_lora.py",
"tests/weights/pcuenq/pokemon-lora/pytorch_lora_weights.bin",
"tests/weights/loras/pcuenq_pokemon_lora.safetensors",
additional_args=["--base-model", "tests/weights/runwayml/stable-diffusion-v1-5"],
expected_hash="a9d7e08e",
)
def convert_preprocessors():
subprocess.run(
[
@@ -632,7 +626,7 @@ def download_all():
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
download_vae_ft_mse()
download_vae_fp16_fix()
download_lora()
download_loras()
download_preprocessors()
download_controlnet()
download_unclip()
@@ -647,7 +641,7 @@ def convert_all():
convert_sdxl()
convert_vae_ft_mse()
convert_vae_fp16_fix()
convert_lora()
# Note: no LoRA conversion step; it is done at runtime by `SDLoraManager`
convert_preprocessors()
convert_controlnet()
convert_unclip()
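To illustrate the note above, here is a minimal sketch of the runtime flow it refers to (the model setup is an assumption; the weights path matches the `download_loras` step above):

```python
import torch

from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SDLoraManager

sd15 = StableDiffusion_1()  # assumed: base weights are loaded separately

# diffusers-format LoRA weights are loaded as-is; SDLoraManager converts
# and attaches them on the fly, so no offline conversion script is needed
tensors = torch.load("tests/weights/loras/pokemon-lora/pytorch_lora_weights.bin")
SDLoraManager(sd15).load(tensors, scale=1.0)
```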


@@ -1,123 +0,0 @@
import random
from typing import Any
from loguru import logger
from pydantic import BaseModel
from torch import Tensor
from torch.utils.data import Dataset
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
LatentDiffusionConfig,
LatentDiffusionTrainer,
TextEmbeddingLatentsBatch,
TextEmbeddingLatentsDataset,
)
class LoraConfig(BaseModel):
rank: int = 32
trigger_phrase: str = ""
use_only_trigger_probability: float = 0.0
unet_targets: list[LoraTarget]
text_encoder_targets: list[LoraTarget]
lda_targets: list[LoraTarget]
class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
def __init__(
self,
trainer: "LoraLatentDiffusionTrainer",
) -> None:
super().__init__(trainer=trainer)
self.trigger_phrase = trainer.config.lora.trigger_phrase
self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
logger.info(f"Trigger phrase: {self.trigger_phrase}")
def process_caption(self, caption: str) -> str:
caption = super().process_caption(caption=caption)
if self.trigger_phrase:
caption = (
self.trigger_phrase
if random.random() < self.use_only_trigger_probability
else f"{self.trigger_phrase} {caption}"
)
return caption
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
dataset: HuggingfaceDatasetConfig
latent_diffusion: LatentDiffusionConfig
lora: LoraConfig
def model_post_init(self, __context: Any) -> None:
"""Pydantic v2 does post init differently, so we need to override this method too."""
logger.info("Freezing models to train only the loras.")
self.models["unet"].train = False
self.models["text_encoder"].train = False
self.models["lda"].train = False
class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
def __init__(
self,
config: LoraLatentDiffusionConfig,
callbacks: "list[Callback[Any]] | None" = None,
) -> None:
super().__init__(config=config, callbacks=callbacks)
self.callbacks.extend((LoadLoras(), SaveLoras()))
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
return TriggerPhraseDataset(trainer=self)
class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
lora_config = trainer.config.lora
for model_name in MODELS:
model = getattr(trainer, model_name)
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
adapter = LoraAdapter[type(model)](
model,
sub_targets=lora_targets(model, model_targets),
rank=lora_config.rank,
)
for sub_adapter, _ in adapter.sub_adapters:
for linear in sub_adapter.Lora.layers(fl.Linear):
linear.requires_grad_(requires_grad=True)
adapter.inject()
class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
tensors: dict[str, Tensor] = {}
metadata: dict[str, str] = {}
for model_name in MODELS:
model = getattr(trainer, model_name)
adapter = model.parent
tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}
save_to_safetensors(
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
tensors=tensors,
metadata=metadata,
)
if __name__ == "__main__":
import sys
config_path = sys.argv[1]
config = LoraLatentDiffusionConfig.load_from_toml(
toml_path=config_path,
)
trainer = LoraLatentDiffusionTrainer(config=config)
trainer.train()


@@ -1,4 +1,5 @@
from typing import Any, Generic, Iterable, TypeVar
from abc import ABC, abstractmethod
from typing import Any
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
@@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
from refiners.fluxion.layers.chain import Chain
class Lora(fl.Chain):
class Lora(fl.Chain, ABC):
def __init__(
self,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.rank = rank
self._scale = scale
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
normal_(tensor=self.down.weight, std=1 / self.rank)
zeros_(tensor=self.up.weight)
@abstractmethod
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
...
@property
def down(self) -> fl.WeightedModule:
down_layer = self[0]
assert isinstance(down_layer, fl.WeightedModule)
return down_layer
@property
def up(self) -> fl.WeightedModule:
up_layer = self[1]
assert isinstance(up_layer, fl.WeightedModule)
return up_layer
@property
def scale(self) -> float:
return self._scale
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(fl.Multiply).scale = value
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Lora":
match (up.ndim, down.ndim):
case (2, 2):
return LinearLora.from_weights(up=up, down=down)
case (4, 4):
return Conv2dLora.from_weights(up=up, down=down)
case _:
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
@classmethod
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
"""
Create a dictionary of LoRA layers from a state dict.
Expects the state dict to be a succession of down and up weights.
"""
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
loras: dict[str, Lora] = {}
for down_key, down_tensor, up_tensor in zip(
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
):
key = ".".join(down_key.split(".")[:-2])
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
return loras
@abstractmethod
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
...
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
class LinearLora(Lora):
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.scale: float = 1.0
super().__init__(
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
fl.Lambda(func=self.scale_outputs),
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "LinearLora":
assert up.ndim == 2 and down.ndim == 2
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
lora = cls(
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Linear):
if isinstance(parent, (Lora, LoraAdapter)):
continue
if exclude is not None and any(
p.__class__.__name__ == e for e in exclude for p in parent.get_parents() + [parent]
):
continue
if layer.in_features == self.in_features and layer.out_features == self.out_features:
return LoraAdapter(target=layer, lora=self), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Linear, fl.Linear]:
return (
fl.Linear(
in_features=self.in_features,
out_features=self.rank,
bias=False,
device=device,
dtype=dtype,
),
fl.Linear(
in_features=self.rank,
out_features=self.out_features,
bias=False,
device=device,
dtype=dtype,
),
)
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
zeros_(tensor=self.Linear_2.weight)
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
def set_scale(self, scale: float) -> None:
self.scale = scale
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
@property
def up_weight(self) -> Tensor:
return self.Linear_2.weight.data
@property
def down_weight(self) -> Tensor:
return self.Linear_1.weight.data
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
class Conv2dLora(Lora):
def __init__(
self,
target: fl.Linear,
in_channels: int,
out_channels: int,
rank: int = 16,
scale: float = 1.0,
kernel_size: tuple[int, int] = (1, 3),
stride: tuple[int, int] = (1, 1),
padding: tuple[int, int] = (0, 1),
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = target.in_features
self.out_features = target.out_features
self.rank = rank
self.scale = scale
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Conv2dLora":
assert up.ndim == 4 and down.ndim == 4
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
down_padding = 1 if down_kernel_size == 3 else 0
up_padding = 1 if up_kernel_size == 3 else 0
lora = cls(
in_channels=down.shape[1],
out_channels=up.shape[0],
rank=down.shape[0],
kernel_size=(down_kernel_size, up_kernel_size),
padding=(down_padding, up_padding),
device=up.device,
dtype=up.dtype,
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Conv2d):
if isinstance(parent, (Lora, LoraAdapter)):
continue
if exclude is not None and any(
p.__class__.__name__ == e for e in exclude for p in parent.get_parents() + [parent]
):
continue
if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
if layer.stride != (self.stride[0], self.stride[0]):
self.down.stride = layer.stride
return LoraAdapter(
target=layer,
lora=self,
), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Conv2d, fl.Conv2d]:
return (
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.rank,
kernel_size=self.kernel_size[0],
stride=self.stride[0],
padding=self.padding[0],
use_bias=False,
device=device,
dtype=dtype,
),
fl.Conv2d(
in_channels=self.rank,
out_channels=self.out_channels,
kernel_size=self.kernel_size[1],
stride=self.stride[1],
padding=self.padding[1],
use_bias=False,
device=device,
dtype=dtype,
),
)
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
with self.setup_adapter(target):
super().__init__(
target,
Lora(
in_features=target.in_features,
out_features=target.out_features,
rank=rank,
device=target.device,
dtype=target.dtype,
),
)
self.Lora.set_scale(scale=scale)
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(
self,
target: T,
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
rank: int | None = None,
scale: float = 1.0,
weights: list[Tensor] | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(target)
if weights is not None:
assert len(weights) % 2 == 0
weights_rank = weights[0].shape[1]
if rank is None:
rank = weights_rank
else:
assert rank == weights_rank
assert rank is not None, "either pass a rank or weights"
self.sub_targets = sub_targets
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
for linear, parent in self.sub_targets:
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
if weights is not None:
assert len(self.sub_adapters) == (len(weights) // 2)
for i, (adapter, _) in enumerate(self.sub_adapters):
lora = adapter.Lora
assert (
lora.rank == weights[i * 2].shape[1]
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
for adapter, adapter_parent in self.sub_adapters:
adapter.inject(adapter_parent)
return super().inject(parent)
def eject(self) -> None:
for adapter, _ in self.sub_adapters:
adapter.eject()
super().eject()
super().__init__(target, lora)
@property
def weights(self) -> list[Tensor]:
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
def lora(self) -> Lora:
return self.ensure_find(Lora)
@property
def scale(self) -> float:
return self.lora.scale
@scale.setter
def scale(self, value: float) -> None:
self.lora.scale = value
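As a rough usage sketch of the refactored classes (the toy chain and tensor shapes are made up for the example): a `Lora` built from raw weights can find a matching layer itself via `auto_attach`, which returns a `LoraAdapter` plus the parent chain to inject into.

```python
import torch

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import LinearLora

# toy chain with a single Linear to adapt
chain = fl.Chain(fl.Linear(in_features=320, out_features=320, bias=False))

rank = 16
down = torch.randn(rank, 320)  # (rank, in_features)
up = torch.zeros(320, rank)  # (out_features, rank)
lora = LinearLora.from_weights(down=down, up=up)

attachment = lora.auto_attach(chain)  # None if no Linear matches the shapes
assert attachment is not None
adapter, parent = attachment
adapter.inject(parent)
adapter.scale = 0.5  # rescale the LoRA contribution after injection
```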


@@ -1,146 +1,140 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Iterator
from warnings import warn
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from refiners.foundationals.latent_diffusion import (
CLIPTextEncoderL,
LatentDiffusionAutoencoder,
SD1UNet,
StableDiffusion_1,
)
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
MODELS = ["unet", "text_encoder", "lda"]
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
class LoraTarget(str, Enum):
Self = "self"
Attention = "Attention"
SelfAttention = "SelfAttention"
CrossAttention = "CrossAttentionBlock2d"
FeedForward = "FeedForward"
TransformerLayer = "TransformerLayer"
def get_class(self) -> type[fl.Chain]:
match self:
case LoraTarget.Self:
return fl.Chain
case LoraTarget.Attention:
return fl.Attention
case LoraTarget.SelfAttention:
return fl.SelfAttention
case LoraTarget.CrossAttention:
return CrossAttentionBlock2d
case LoraTarget.FeedForward:
return FeedForward
case LoraTarget.TransformerLayer:
return TransformerLayer
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt other LoRAs
raise StopIteration
if isinstance(m, Controlnet): # do not adapt Controlnet linears
raise StopIteration
return isinstance(m, k)
return f
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
for m, p in module.walk(_predicate(fl.Linear)):
assert isinstance(m, fl.Linear)
yield (m, p)
def lora_targets(
module: fl.Chain,
target: LoraTarget | list[LoraTarget],
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
if isinstance(target, list):
for t in target:
yield from lora_targets(module, t)
return
if target == LoraTarget.Self:
yield from _iter_linears(module)
return
for layer, _ in module.walk(_predicate(target.get_class())):
assert isinstance(layer, fl.Chain)
yield from _iter_linears(layer)
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
metadata: dict[str, str] | None
tensors: dict[str, Tensor]
class SDLoraManager:
def __init__(
self,
target: StableDiffusion_1,
sub_targets: dict[str, list[LoraTarget]],
target: LatentDiffusionModel,
) -> None:
self.target = target
@property
def unet(self) -> fl.Chain:
unet = self.target.unet
assert isinstance(unet, fl.Chain)
return unet
@property
def clip_text_encoder(self) -> fl.Chain:
clip_text_encoder = self.target.clip_text_encoder
assert isinstance(clip_text_encoder, fl.Chain)
return clip_text_encoder
def load(
self,
tensors: dict[str, Tensor],
/,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
):
with self.setup_adapter(target):
super().__init__(target)
) -> None:
"""Load the LoRA weights from a dictionary of tensors.
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
for model_name in MODELS:
if not (model_targets := sub_targets.get(model_name, [])):
continue
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
self.sub_adapters.append(
LoraAdapter[type(model)](
model,
sub_targets=lora_targets(model, model_targets),
scale=scale,
weights=lora_weights,
)
)
@classmethod
def from_safetensors(
cls,
target: StableDiffusion_1,
checkpoint_path: Path | str,
scale: float = 1.0,
):
metadata = load_metadata_from_safetensors(checkpoint_path)
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)
sub_targets: dict[str, list[LoraTarget]] = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
return cls(
target,
sub_targets,
scale=scale,
weights=tensors,
Expects the keys to be in the commonly found formats on CivitAI's hub.
"""
assert len(self.lora_adapters) == 0, "Loras already loaded"
loras = Lora.from_dict(
{key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
)
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
for adapter in self.sub_adapters:
adapter.inject()
return super().inject(parent)
# if no key contains "unet" or "text", assume all keys are for the unet
if not "unet" in loras and not "text" in loras:
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
def eject(self) -> None:
for adapter in self.sub_adapters:
adapter.eject()
super().eject()
self.load_unet(loras)
self.load_text_encoder(loras)
self.scale = scale
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
def load_unet(self, loras: dict[str, Lora], /) -> None:
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
exclude = [
self.unet_exclusions[exclusion]
for exclusion in self.unet_exclusions
if all(exclusion not in key for key in unet_loras.keys())
]
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
def unload(self) -> None:
for lora_adapter in self.lora_adapters:
lora_adapter.eject()
@property
def loras(self) -> list[Lora]:
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
@property
def lora_adapters(self) -> list[LoraAdapter]:
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
@property
def unet_exclusions(self) -> dict[str, str]:
return {
"time": "TimestepEncoder",
"res": "ResidualBlock",
"downsample": "DownsampleBlock",
"upsample": "UpsampleBlock",
}
@property
def scale(self) -> float:
assert len(self.loras) > 0, "No loras found"
assert all(lora.scale == self.loras[0].scale for lora in self.loras)
return self.loras[0].scale
@scale.setter
def scale(self, value: float) -> None:
for lora in self.loras:
lora.scale = value
@staticmethod
def pad(input: str, /, padding_length: int = 2) -> str:
new_split: list[str] = []
for s in input.split("_"):
if s.isdigit():
new_split.append(s.zfill(padding_length))
else:
new_split.append(s)
return "_".join(new_split)
@staticmethod
def sort_keys(key: str, /) -> tuple[str, int]:
# out0 sometimes appears as an alias for out; this dict might not be exhaustive
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
for i, s in enumerate(key.split("_")):
if s in key_char_order:
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
return (prefix, key_char_order[s])
return (SDLoraManager.pad(key), 5)
@staticmethod
def auto_attach(
loras: dict[str, Lora],
target: fl.Chain,
/,
exclude: list[str] | None = None,
) -> None:
failed_loras: dict[str, Lora] = {}
for key, lora in loras.items():
if attach := lora.auto_attach(target, exclude=exclude):
adapter, parent = attach
adapter.inject(parent)
else:
failed_loras[key] = lora
if failed_loras:
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
# TODO: add a stronger sanity check to make sure loras are attached correctly
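To make the new entry point concrete, here is a hedged sketch of the manager's lifecycle, mirroring the e2e tests below (the checkpoint path and model setup are assumptions):

```python
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SDLoraManager

sd15 = StableDiffusion_1()  # assumed: base weights are loaded elsewhere
manager = SDLoraManager(sd15)

# keys in the usual CivitAI formats are sorted, matched and auto-attached
tensors = load_from_safetensors("pytorch_lora_weights.safetensors")  # assumed path
manager.load(tensors, scale=1.0)

manager.scale = 0.7  # rescales every attached LoRA at once
manager.unload()  # ejects all LoraAdapters, restoring the base model
```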


@@ -1,82 +0,0 @@
from torch import allclose, randn
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora, LoraAdapter, SingleLoraAdapter
def test_single_lora_adapter() -> None:
chain = fl.Chain(
fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
),
fl.Linear(in_features=1, out_features=2),
)
x = randn(1, 1)
y = chain(x)
lora_adapter = SingleLoraAdapter(chain.Chain.Linear_1).inject(chain.Chain)
assert isinstance(lora_adapter[1], Lora)
assert allclose(input=chain(x), other=y)
assert lora_adapter.parent == chain.Chain
lora_adapter.eject()
assert isinstance(chain.Chain[0], fl.Linear)
assert len(chain) == 2
lora_adapter.inject(chain.Chain)
assert isinstance(chain.Chain[0], SingleLoraAdapter)
def test_lora_adapter() -> None:
chain = fl.Chain(
fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
),
fl.Linear(in_features=1, out_features=2),
)
# create and inject twice
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
assert len(list(chain.layers(Lora))) == 3
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
assert len(list(chain.layers(Lora))) == 6
# If we init a LoRA when another LoRA is already injected, the Linear
# layers of the first LoRA will be adapted too, which is typically not
# what we want.
# This issue can be avoided either by making the predicate for
# `walk` raise StopIteration when it encounters a LoRA (see the SD LoRA)
# or by creating all the LoRA Adapters first, before injecting them
# (see below).
assert len(list(chain.layers(Lora, recurse=True))) == 12
# ejection in forward order
a1.eject()
assert len(list(chain.layers(Lora))) == 3
a2.eject()
assert len(list(chain.layers(Lora))) == 0
# create twice then inject twice
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
a1.inject()
a2.inject()
assert len(list(chain.layers(Lora))) == 6
# If we inject after init we do not have the target selection problem,
# the LoRA layers are not adapted.
assert len(list(chain.layers(Lora, recurse=True))) == 6
# ejection in reverse order
a2.eject()
assert len(list(chain.layers(Lora))) == 3
a1.eject()
assert len(list(chain.layers(Lora))) == 0


@@ -19,7 +19,7 @@ from refiners.foundationals.latent_diffusion import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart
@@ -182,14 +182,38 @@ def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
return name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
weights_path = test_weights_path / "loras" / "pcuenq_pokemon_lora.safetensors"
return expected_image, weights_path
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
tensors = torch.load(weights_path) # type: ignore
return expected_image, tensors
@pytest.fixture(scope="module")
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
if not weights_path.is_file():
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_from_safetensors(weights_path)
return expected_image, tensors
@pytest.fixture
@@ -1010,24 +1034,20 @@ def test_diffusion_controlnet_stack(
@no_grad()
def test_diffusion_lora(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]],
test_device: torch.device,
):
) -> None:
sd15 = sd15_std
n_steps = 30
expected_image, lora_weights_path = lora_data_pokemon
if not lora_weights_path.is_file():
warn(f"could not find weights at {lora_weights_path}, skipping")
pytest.skip(allow_module_level=True)
expected_image, lora_weights = lora_data_pokemon
prompt = "a cute cat"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
SDLoraManager(sd15).load(lora_weights, scale=1)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -1045,77 +1065,45 @@ def test_diffusion_lora(
@no_grad()
def test_diffusion_lora_float16(
sd15_std_float16: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std_float16
n_steps = 30
def test_diffusion_sdxl_lora(
sdxl_ddim: StableDiffusion_XL,
lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
) -> None:
sdxl = sdxl_ddim
expected_image, lora_weights = lora_data_dpo
expected_image, lora_weights_path = lora_data_pokemon
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
# except that we are using DDIM instead of sde-dpmsolver++
n_steps = 40
seed = 12341234123
guidance_scale = 7.5
lora_scale = 1.4
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
if not lora_weights_path.is_file():
warn(f"could not find weights at {lora_weights_path}, skipping")
pytest.skip(allow_module_level=True)
SDLoraManager(sdxl).load(lora_weights, scale=lora_scale)
prompt = "a cute cat"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
sd15.set_num_inference_steps(n_steps)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
manual_seed(seed=seed)
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
for step in sd15.steps:
x = sd15(
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=guidance_scale,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
@no_grad()
def test_diffusion_lora_twice(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
expected_image, lora_weights_path = lora_data_pokemon
if not lora_weights_path.is_file():
warn(f"could not find weights at {lora_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
# The same LoRA is used twice, which is not a common use case: this is purely for testing purposes
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject()
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject()
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


@@ -48,6 +48,7 @@ Special cases:
- `expected_restart.png`
- `expected_freeu.png`
- `expected_dropy_slime_9752.png`
- `expected_sdxl_dpo_lora.png`
## Other images

Binary file not shown (new image file added, 1.6 MiB).