mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
refactor Lora LoraAdapter and the latent_diffusion/lora file
This commit is contained in:
parent
dd87b9706e
commit
a1f50f3f9d
10
README.md
10
README.md
|
@ -158,16 +158,6 @@ You should get:
|
|||
|
||||
![dropy slime output](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_slime_9752.png)
|
||||
|
||||
### Training
|
||||
|
||||
Refiners has a built-in training utils library and provides scripts that can be used as a starting point.
|
||||
|
||||
E.g. to train a LoRA on top of Stable Diffusion, copy and edit `configs/finetune-lora.toml` to suit your needs and launch the training as follows:
|
||||
|
||||
```bash
|
||||
python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml
|
||||
```
|
||||
|
||||
## Adapter Zoo
|
||||
|
||||
For now, given [finegrain](https://finegrain.ai)'s mission, we are focusing on image edition tasks. We support:
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
[wandb]
|
||||
mode = "offline" # "online", "offline", "disabled"
|
||||
entity = "acme"
|
||||
project = "test-lora-training"
|
||||
|
||||
[models]
|
||||
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
|
||||
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
|
||||
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
|
||||
|
||||
[latent_diffusion]
|
||||
unconditional_sampling_probability = 0.05
|
||||
offset_noise = 0.1
|
||||
|
||||
[lora]
|
||||
rank = 16
|
||||
trigger_phrase = "a spsh photo,"
|
||||
use_only_trigger_probability = 1.0
|
||||
unet_targets = ["CrossAttentionBlock2d"]
|
||||
text_encoder_targets = ["TransformerLayer"]
|
||||
lda_targets = []
|
||||
|
||||
[training]
|
||||
duration = "1000:epoch"
|
||||
seed = 0
|
||||
gpu_index = 0
|
||||
batch_size = 4
|
||||
gradient_accumulation = "4:step"
|
||||
clip_grad_norm = 1.0
|
||||
# clip_grad_value = 1.0
|
||||
evaluation_interval = "5:epoch"
|
||||
evaluation_seed = 1
|
||||
|
||||
|
||||
[optimizer]
|
||||
optimizer = "Prodigy" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
|
||||
learning_rate = 1
|
||||
betas = [0.9, 0.999]
|
||||
eps = 1e-8
|
||||
weight_decay = 1e-2
|
||||
|
||||
[scheduler]
|
||||
scheduler_type = "ConstantLR"
|
||||
update_interval = "1:step"
|
||||
warmup = "500:step"
|
||||
|
||||
|
||||
[dropout]
|
||||
dropout_probability = 0.2
|
||||
use_gyro_dropout = false
|
||||
|
||||
[dataset]
|
||||
hf_repo = "acme/images"
|
||||
revision = "main"
|
||||
|
||||
[checkpointing]
|
||||
# save_folder = "/path/to/ckpts"
|
||||
save_interval = "1:step"
|
||||
|
||||
[test_diffusion]
|
||||
num_inference_steps = 30
|
||||
use_short_prompts = false
|
||||
prompts = [
|
||||
"a cute cat",
|
||||
"a cute dog",
|
||||
"a cute bird",
|
||||
"a cute horse",
|
||||
]
|
|
@ -1,149 +0,0 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline # type: ignore
|
||||
from torch import Tensor
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
from torch.nn.init import zeros_
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import no_grad, save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
|
||||
|
||||
|
||||
def get_weight(linear: fl.Linear) -> torch.Tensor:
|
||||
assert linear.bias is None
|
||||
return linear.state_dict()["weight"]
|
||||
|
||||
|
||||
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
|
||||
weights: list[torch.Tensor] = []
|
||||
for lora in module.layers(layer_type=Lora):
|
||||
linears = list(lora.layers(layer_type=fl.Linear))
|
||||
assert len(linears) == 2
|
||||
weights.extend((get_weight(linear=linears[1]), get_weight(linear=linears[0]))) # aka (up_weight, down_weight)
|
||||
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
|
||||
|
||||
|
||||
class Args(argparse.Namespace):
|
||||
source_path: str
|
||||
base_model: str
|
||||
output_file: str
|
||||
verbose: bool
|
||||
|
||||
|
||||
@no_grad()
|
||||
def process(args: Args) -> None:
|
||||
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
diffusers_sd = DiffusionPipeline.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.base_model,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
||||
|
||||
refiners_model = SD1UNet(in_channels=4)
|
||||
target = LoraTarget.CrossAttention
|
||||
metadata = {"unet_targets": "CrossAttentionBlock2d"}
|
||||
rank = diffusers_state_dict[
|
||||
"mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
|
||||
].shape[0]
|
||||
|
||||
x = torch.randn(1, 4, 32, 32)
|
||||
timestep = torch.tensor(data=[0])
|
||||
clip_text_embeddings = torch.randn(1, 77, 768)
|
||||
|
||||
refiners_model.set_timestep(timestep=timestep)
|
||||
refiners_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
||||
refiners_args = (x,)
|
||||
|
||||
diffusers_args = (x, timestep, clip_text_embeddings)
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=refiners_model, target_model=diffusers_model, skip_output_check=True, verbose=args.verbose
|
||||
)
|
||||
if not converter.run(
|
||||
source_args=refiners_args,
|
||||
target_args=diffusers_args,
|
||||
):
|
||||
raise RuntimeError("Model conversion failed")
|
||||
|
||||
diffusers_to_refiners = converter.get_mapping()
|
||||
|
||||
LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
|
||||
|
||||
for layer in refiners_model.layers(layer_type=Lora):
|
||||
zeros_(tensor=layer.Linear_1.weight)
|
||||
|
||||
targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
|
||||
for target_k in targets:
|
||||
k_p, k_s = target_k.split(".processor.")
|
||||
r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
|
||||
assert len(r) == 1
|
||||
orig_k = r[0]
|
||||
orig_path = orig_k.split(sep=".")
|
||||
p = refiners_model
|
||||
for seg in orig_path[:-1]:
|
||||
p = p[seg]
|
||||
assert isinstance(p, fl.Chain)
|
||||
last_seg = (
|
||||
"SingleLoraAdapter"
|
||||
if orig_path[-1] == "Linear"
|
||||
else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
|
||||
)
|
||||
p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
|
||||
p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
|
||||
p[last_seg].Lora.load_weights(p_down, p_up)
|
||||
|
||||
state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
|
||||
assert len(state_dict) == 320
|
||||
save_to_safetensors(path=args.output_path, tensors=state_dict, metadata=metadata)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Convert LoRAs saved using the diffusers library to refiners format.")
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
dest="source_path",
|
||||
required=True,
|
||||
help="Source file path (.bin|safetensors) containing the LoRAs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-model",
|
||||
type=str,
|
||||
required=False,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
help="Base model, used for the UNet structure. Default: runwayml/stable-diffusion-v1-5",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
required=False,
|
||||
default=None,
|
||||
help=(
|
||||
"Output file path (.safetensors) for converted LoRAs. If not provided, the output path will be the same as"
|
||||
" the source path."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
dest="verbose",
|
||||
default=False,
|
||||
help="Use this flag to print verbose output during conversion.",
|
||||
)
|
||||
args = parser.parse_args(namespace=Args())
|
||||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
|
||||
process(args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,124 +0,0 @@
|
|||
import argparse
|
||||
from functools import partial
|
||||
|
||||
from convert_diffusers_unet import Args as UnetConversionArgs, setup_converter as convert_unet
|
||||
from convert_transformers_clip_text_model import (
|
||||
Args as TextEncoderConversionArgs,
|
||||
setup_converter as convert_text_encoder,
|
||||
)
|
||||
from torch import Tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import (
|
||||
load_from_safetensors,
|
||||
load_metadata_from_safetensors,
|
||||
save_to_safetensors,
|
||||
)
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
||||
|
||||
|
||||
def get_unet_mapping(source_path: str) -> dict[str, str]:
|
||||
args = UnetConversionArgs(source_path=source_path, verbose=False)
|
||||
return convert_unet(args=args).get_mapping()
|
||||
|
||||
|
||||
def get_text_encoder_mapping(source_path: str) -> dict[str, str]:
|
||||
args = TextEncoderConversionArgs(source_path=source_path, subfolder="text_encoder", verbose=False)
|
||||
return convert_text_encoder(
|
||||
args=args,
|
||||
).get_mapping()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Converts a refiner's LoRA weights to SD-WebUI's LoRA weights")
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the input file with refiner's LoRA weights (safetensors format)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-file",
|
||||
type=str,
|
||||
default="sdwebui_loras.safetensors",
|
||||
help="Path to the output file with sd-webui's LoRA weights (safetensors format)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd15",
|
||||
type=str,
|
||||
required=False,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata = load_metadata_from_safetensors(path=args.input_file)
|
||||
assert metadata is not None, f"Could not load metadata from {args.input_file}"
|
||||
tensors = load_from_safetensors(path=args.input_file)
|
||||
|
||||
state_dict: dict[str, Tensor] = {}
|
||||
|
||||
for meta_key, meta_value in metadata.items():
|
||||
match meta_key:
|
||||
case "unet_targets":
|
||||
model = SD1UNet(in_channels=4)
|
||||
create_mapping = partial(get_unet_mapping, source_path=args.sd15)
|
||||
key_prefix = "unet."
|
||||
lora_prefix = "lora_unet_"
|
||||
case "text_encoder_targets":
|
||||
model = CLIPTextEncoderL()
|
||||
create_mapping = partial(get_text_encoder_mapping, source_path=args.sd15)
|
||||
key_prefix = "text_encoder."
|
||||
lora_prefix = "lora_te_"
|
||||
case "lda_targets":
|
||||
raise ValueError("SD-WebUI does not support LoRA for the auto-encoder")
|
||||
case _:
|
||||
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
|
||||
|
||||
submodule_to_key: dict[fl.Module, str] = {}
|
||||
for name, submodule in model.named_modules():
|
||||
submodule_to_key[submodule] = name
|
||||
|
||||
# SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.:
|
||||
#
|
||||
# lora_unet_down_blocks_0_attentions_0_proj_in.alpha
|
||||
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight
|
||||
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight
|
||||
# ...
|
||||
#
|
||||
# Internally SD-WebUI has some logic[1] to convert such keys into the CompVis format. See
|
||||
# `convert_diffusers_name_to_compvis` for more details.
|
||||
#
|
||||
# [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225
|
||||
|
||||
refiners_to_diffusers = create_mapping()
|
||||
assert refiners_to_diffusers is not None
|
||||
|
||||
# Compute the corresponding diffusers' keys where LoRA layers must be applied
|
||||
lora_injection_points: list[str] = [
|
||||
refiners_to_diffusers[submodule_to_key[linear]]
|
||||
for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
|
||||
for layer in model.layers(layer_type=target.get_class())
|
||||
for linear in layer.layers(layer_type=fl.Linear)
|
||||
]
|
||||
|
||||
lora_weights = [w for w in [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]]
|
||||
assert len(lora_injection_points) == len(lora_weights) // 2
|
||||
|
||||
# Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores)
|
||||
for i, diffusers_key in enumerate(iterable=lora_injection_points):
|
||||
lora_key = lora_prefix + diffusers_key.replace(".", "_")
|
||||
# Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). Refiners uses a scale = 1.0
|
||||
# by default (see `lora_calc_updown` in SD-WebUI for more details)
|
||||
state_dict[lora_key + ".lora_up.weight"] = lora_weights[2 * i]
|
||||
state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
|
||||
|
||||
assert state_dict
|
||||
save_to_safetensors(path=args.output_file, tensors=state_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -229,10 +229,15 @@ def download_vae_ft_mse():
|
|||
)
|
||||
|
||||
|
||||
def download_lora():
|
||||
dest_folder = os.path.join(test_weights_dir, "pcuenq", "pokemon-lora")
|
||||
def download_loras():
|
||||
dest_folder = os.path.join(test_weights_dir, "loras", "pokemon-lora")
|
||||
download_file("https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin", dest_folder)
|
||||
|
||||
dest_folder = os.path.join(test_weights_dir, "loras", "dpo-lora")
|
||||
download_file(
|
||||
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
|
||||
)
|
||||
|
||||
|
||||
def download_preprocessors():
|
||||
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
|
||||
|
@ -454,17 +459,6 @@ def convert_vae_fp16_fix():
|
|||
)
|
||||
|
||||
|
||||
def convert_lora():
|
||||
os.makedirs("tests/weights/loras", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_lora.py",
|
||||
"tests/weights/pcuenq/pokemon-lora/pytorch_lora_weights.bin",
|
||||
"tests/weights/loras/pcuenq_pokemon_lora.safetensors",
|
||||
additional_args=["--base-model", "tests/weights/runwayml/stable-diffusion-v1-5"],
|
||||
expected_hash="a9d7e08e",
|
||||
)
|
||||
|
||||
|
||||
def convert_preprocessors():
|
||||
subprocess.run(
|
||||
[
|
||||
|
@ -632,7 +626,7 @@ def download_all():
|
|||
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
download_vae_ft_mse()
|
||||
download_vae_fp16_fix()
|
||||
download_lora()
|
||||
download_loras()
|
||||
download_preprocessors()
|
||||
download_controlnet()
|
||||
download_unclip()
|
||||
|
@ -647,7 +641,7 @@ def convert_all():
|
|||
convert_sdxl()
|
||||
convert_vae_ft_mse()
|
||||
convert_vae_fp16_fix()
|
||||
convert_lora()
|
||||
# Note: no convert loras: this is done at runtime by `SDLoraManager`
|
||||
convert_preprocessors()
|
||||
convert_controlnet()
|
||||
convert_unclip()
|
||||
|
|
|
@ -1,123 +0,0 @@
|
|||
import random
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
|
||||
from refiners.training_utils.latent_diffusion import (
|
||||
FinetuneLatentDiffusionConfig,
|
||||
LatentDiffusionConfig,
|
||||
LatentDiffusionTrainer,
|
||||
TextEmbeddingLatentsBatch,
|
||||
TextEmbeddingLatentsDataset,
|
||||
)
|
||||
|
||||
|
||||
class LoraConfig(BaseModel):
|
||||
rank: int = 32
|
||||
trigger_phrase: str = ""
|
||||
use_only_trigger_probability: float = 0.0
|
||||
unet_targets: list[LoraTarget]
|
||||
text_encoder_targets: list[LoraTarget]
|
||||
lda_targets: list[LoraTarget]
|
||||
|
||||
|
||||
class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
|
||||
def __init__(
|
||||
self,
|
||||
trainer: "LoraLatentDiffusionTrainer",
|
||||
) -> None:
|
||||
super().__init__(trainer=trainer)
|
||||
self.trigger_phrase = trainer.config.lora.trigger_phrase
|
||||
self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
|
||||
logger.info(f"Trigger phrase: {self.trigger_phrase}")
|
||||
|
||||
def process_caption(self, caption: str) -> str:
|
||||
caption = super().process_caption(caption=caption)
|
||||
if self.trigger_phrase:
|
||||
caption = (
|
||||
f"{self.trigger_phrase} {caption}"
|
||||
if random.random() < self.use_only_trigger_probability
|
||||
else self.trigger_phrase
|
||||
)
|
||||
return caption
|
||||
|
||||
|
||||
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||
dataset: HuggingfaceDatasetConfig
|
||||
latent_diffusion: LatentDiffusionConfig
|
||||
lora: LoraConfig
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Pydantic v2 does post init differently, so we need to override this method too."""
|
||||
logger.info("Freezing models to train only the loras.")
|
||||
self.models["unet"].train = False
|
||||
self.models["text_encoder"].train = False
|
||||
self.models["lda"].train = False
|
||||
|
||||
|
||||
class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
|
||||
def __init__(
|
||||
self,
|
||||
config: LoraLatentDiffusionConfig,
|
||||
callbacks: "list[Callback[Any]] | None" = None,
|
||||
) -> None:
|
||||
super().__init__(config=config, callbacks=callbacks)
|
||||
self.callbacks.extend((LoadLoras(), SaveLoras()))
|
||||
|
||||
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
|
||||
return TriggerPhraseDataset(trainer=self)
|
||||
|
||||
|
||||
class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
|
||||
def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
|
||||
lora_config = trainer.config.lora
|
||||
|
||||
for model_name in MODELS:
|
||||
model = getattr(trainer, model_name)
|
||||
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
|
||||
adapter = LoraAdapter[type(model)](
|
||||
model,
|
||||
sub_targets=lora_targets(model, model_targets),
|
||||
rank=lora_config.rank,
|
||||
)
|
||||
for sub_adapter, _ in adapter.sub_adapters:
|
||||
for linear in sub_adapter.Lora.layers(fl.Linear):
|
||||
linear.requires_grad_(requires_grad=True)
|
||||
adapter.inject()
|
||||
|
||||
|
||||
class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
|
||||
def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
|
||||
tensors: dict[str, Tensor] = {}
|
||||
metadata: dict[str, str] = {}
|
||||
|
||||
for model_name in MODELS:
|
||||
model = getattr(trainer, model_name)
|
||||
adapter = model.parent
|
||||
tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
|
||||
metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}
|
||||
|
||||
save_to_safetensors(
|
||||
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
|
||||
tensors=tensors,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
config_path = sys.argv[1]
|
||||
config = LoraLatentDiffusionConfig.load_from_toml(
|
||||
toml_path=config_path,
|
||||
)
|
||||
trainer = LoraLatentDiffusionTrainer(config=config)
|
||||
trainer.train()
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Any, Generic, Iterable, TypeVar
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
|
@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_
|
|||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
|
||||
T = TypeVar("T", bound=fl.Chain)
|
||||
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
|
||||
from refiners.fluxion.layers.chain import Chain
|
||||
|
||||
|
||||
class Lora(fl.Chain):
|
||||
class Lora(fl.Chain, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
rank: int = 16,
|
||||
scale: float = 1.0,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.rank = rank
|
||||
self._scale = scale
|
||||
|
||||
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
|
||||
|
||||
normal_(tensor=self.down.weight, std=1 / self.rank)
|
||||
zeros_(tensor=self.up.weight)
|
||||
|
||||
@abstractmethod
|
||||
def lora_layers(
|
||||
self, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
|
||||
...
|
||||
|
||||
@property
|
||||
def down(self) -> fl.WeightedModule:
|
||||
down_layer = self[0]
|
||||
assert isinstance(down_layer, fl.WeightedModule)
|
||||
return down_layer
|
||||
|
||||
@property
|
||||
def up(self) -> fl.WeightedModule:
|
||||
up_layer = self[1]
|
||||
assert isinstance(up_layer, fl.WeightedModule)
|
||||
return up_layer
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
return self._scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
self._scale = value
|
||||
self.ensure_find(fl.Multiply).scale = value
|
||||
|
||||
@classmethod
|
||||
def from_weights(
|
||||
cls,
|
||||
down: Tensor,
|
||||
up: Tensor,
|
||||
) -> "Lora":
|
||||
match (up.ndim, down.ndim):
|
||||
case (2, 2):
|
||||
return LinearLora.from_weights(up=up, down=down)
|
||||
case (4, 4):
|
||||
return Conv2dLora.from_weights(up=up, down=down)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
|
||||
"""
|
||||
Create a dictionary of LoRA layers from a state dict.
|
||||
|
||||
Expects the state dict to be a succession of down and up weights.
|
||||
"""
|
||||
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
|
||||
loras: dict[str, Lora] = {}
|
||||
for down_key, down_tensor, up_tensor in zip(
|
||||
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
|
||||
):
|
||||
key = ".".join(down_key.split(".")[:-2])
|
||||
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
|
||||
return loras
|
||||
|
||||
@abstractmethod
|
||||
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
|
||||
...
|
||||
|
||||
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||
assert down_weight.shape == self.down.weight.shape
|
||||
assert up_weight.shape == self.up.weight.shape
|
||||
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
|
||||
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
|
||||
|
||||
|
||||
class LinearLora(Lora):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
rank: int = 16,
|
||||
scale: float = 1.0,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.rank = rank
|
||||
self.scale: float = 1.0
|
||||
|
||||
super().__init__(
|
||||
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
|
||||
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
|
||||
fl.Lambda(func=self.scale_outputs),
|
||||
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
|
||||
|
||||
@classmethod
|
||||
def from_weights(
|
||||
cls,
|
||||
down: Tensor,
|
||||
up: Tensor,
|
||||
) -> "LinearLora":
|
||||
assert up.ndim == 2 and down.ndim == 2
|
||||
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
|
||||
lora = cls(
|
||||
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
|
||||
)
|
||||
lora.load_weights(down_weight=down, up_weight=up)
|
||||
return lora
|
||||
|
||||
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
|
||||
zeros_(tensor=self.Linear_2.weight)
|
||||
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
|
||||
for layer, parent in target.walk(fl.Linear):
|
||||
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
|
||||
continue
|
||||
|
||||
def scale_outputs(self, x: Tensor) -> Tensor:
|
||||
return x * self.scale
|
||||
if exclude is not None and any(
|
||||
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
|
||||
):
|
||||
continue
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
self.scale = scale
|
||||
if layer.in_features == self.in_features and layer.out_features == self.out_features:
|
||||
return LoraAdapter(target=layer, lora=self), parent
|
||||
|
||||
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
|
||||
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
|
||||
|
||||
@property
|
||||
def up_weight(self) -> Tensor:
|
||||
return self.Linear_2.weight.data
|
||||
|
||||
@property
|
||||
def down_weight(self) -> Tensor:
|
||||
return self.Linear_1.weight.data
|
||||
|
||||
|
||||
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
|
||||
def __init__(
|
||||
self,
|
||||
target: fl.Linear,
|
||||
rank: int = 16,
|
||||
scale: float = 1.0,
|
||||
) -> None:
|
||||
self.in_features = target.in_features
|
||||
self.out_features = target.out_features
|
||||
self.rank = rank
|
||||
self.scale = scale
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(
|
||||
target,
|
||||
Lora(
|
||||
in_features=target.in_features,
|
||||
out_features=target.out_features,
|
||||
rank=rank,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
def lora_layers(
|
||||
self, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> tuple[fl.Linear, fl.Linear]:
|
||||
return (
|
||||
fl.Linear(
|
||||
in_features=self.in_features,
|
||||
out_features=self.rank,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Linear(
|
||||
in_features=self.rank,
|
||||
out_features=self.out_features,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
self.Lora.set_scale(scale=scale)
|
||||
|
||||
|
||||
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
class Conv2dLora(Lora):
|
||||
def __init__(
|
||||
self,
|
||||
target: T,
|
||||
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
|
||||
rank: int | None = None,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
rank: int = 16,
|
||||
scale: float = 1.0,
|
||||
weights: list[Tensor] | None = None,
|
||||
kernel_size: tuple[int, int] = (1, 3),
|
||||
stride: tuple[int, int] = (1, 1),
|
||||
padding: tuple[int, int] = (0, 1),
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
|
||||
|
||||
@classmethod
|
||||
def from_weights(
|
||||
cls,
|
||||
down: Tensor,
|
||||
up: Tensor,
|
||||
) -> "Conv2dLora":
|
||||
assert up.ndim == 4 and down.ndim == 4
|
||||
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
|
||||
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
|
||||
down_padding = 1 if down_kernel_size == 3 else 0
|
||||
up_padding = 1 if up_kernel_size == 3 else 0
|
||||
lora = cls(
|
||||
in_channels=down.shape[1],
|
||||
out_channels=up.shape[0],
|
||||
rank=down.shape[0],
|
||||
kernel_size=(down_kernel_size, up_kernel_size),
|
||||
padding=(down_padding, up_padding),
|
||||
device=up.device,
|
||||
dtype=up.dtype,
|
||||
)
|
||||
lora.load_weights(down_weight=down, up_weight=up)
|
||||
return lora
|
||||
|
||||
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
|
||||
for layer, parent in target.walk(fl.Conv2d):
|
||||
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
|
||||
continue
|
||||
|
||||
if exclude is not None and any(
|
||||
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
|
||||
):
|
||||
continue
|
||||
|
||||
if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
|
||||
if layer.stride != (self.stride[0], self.stride[0]):
|
||||
self.down.stride = layer.stride
|
||||
|
||||
return LoraAdapter(
|
||||
target=layer,
|
||||
lora=self,
|
||||
), parent
|
||||
|
||||
def lora_layers(
|
||||
self, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> tuple[fl.Conv2d, fl.Conv2d]:
|
||||
return (
|
||||
fl.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.rank,
|
||||
kernel_size=self.kernel_size[0],
|
||||
stride=self.stride[0],
|
||||
padding=self.padding[0],
|
||||
use_bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Conv2d(
|
||||
in_channels=self.rank,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=self.kernel_size[1],
|
||||
stride=self.stride[1],
|
||||
padding=self.padding[1],
|
||||
use_bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
||||
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
if weights is not None:
|
||||
assert len(weights) % 2 == 0
|
||||
weights_rank = weights[0].shape[1]
|
||||
if rank is None:
|
||||
rank = weights_rank
|
||||
else:
|
||||
assert rank == weights_rank
|
||||
|
||||
assert rank is not None, "either pass a rank or weights"
|
||||
|
||||
self.sub_targets = sub_targets
|
||||
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
|
||||
|
||||
for linear, parent in self.sub_targets:
|
||||
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
|
||||
|
||||
if weights is not None:
|
||||
assert len(self.sub_adapters) == (len(weights) // 2)
|
||||
for i, (adapter, _) in enumerate(self.sub_adapters):
|
||||
lora = adapter.Lora
|
||||
assert (
|
||||
lora.rank == weights[i * 2].shape[1]
|
||||
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
|
||||
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
|
||||
|
||||
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
|
||||
for adapter, adapter_parent in self.sub_adapters:
|
||||
adapter.inject(adapter_parent)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
for adapter, _ in self.sub_adapters:
|
||||
adapter.eject()
|
||||
super().eject()
|
||||
super().__init__(target, lora)
|
||||
|
||||
@property
|
||||
def weights(self) -> list[Tensor]:
|
||||
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
|
||||
def lora(self) -> Lora:
|
||||
return self.ensure_find(Lora)
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
return self.lora.scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
self.lora.scale = value
|
||||
|
|
|
@ -1,146 +1,140 @@
|
|||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator
|
||||
from warnings import warn
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
||||
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
||||
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
||||
from refiners.foundationals.latent_diffusion import (
|
||||
CLIPTextEncoderL,
|
||||
LatentDiffusionAutoencoder,
|
||||
SD1UNet,
|
||||
StableDiffusion_1,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
||||
|
||||
MODELS = ["unet", "text_encoder", "lda"]
|
||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
|
||||
|
||||
class LoraTarget(str, Enum):
|
||||
Self = "self"
|
||||
Attention = "Attention"
|
||||
SelfAttention = "SelfAttention"
|
||||
CrossAttention = "CrossAttentionBlock2d"
|
||||
FeedForward = "FeedForward"
|
||||
TransformerLayer = "TransformerLayer"
|
||||
|
||||
def get_class(self) -> type[fl.Chain]:
|
||||
match self:
|
||||
case LoraTarget.Self:
|
||||
return fl.Chain
|
||||
case LoraTarget.Attention:
|
||||
return fl.Attention
|
||||
case LoraTarget.SelfAttention:
|
||||
return fl.SelfAttention
|
||||
case LoraTarget.CrossAttention:
|
||||
return CrossAttentionBlock2d
|
||||
case LoraTarget.FeedForward:
|
||||
return FeedForward
|
||||
case LoraTarget.TransformerLayer:
|
||||
return TransformerLayer
|
||||
|
||||
|
||||
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
||||
def f(m: fl.Module, _: fl.Chain) -> bool:
|
||||
if isinstance(m, Lora): # do not adapt other LoRAs
|
||||
raise StopIteration
|
||||
if isinstance(m, Controlnet): # do not adapt Controlnet linears
|
||||
raise StopIteration
|
||||
return isinstance(m, k)
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
for m, p in module.walk(_predicate(fl.Linear)):
|
||||
assert isinstance(m, fl.Linear)
|
||||
yield (m, p)
|
||||
|
||||
|
||||
def lora_targets(
|
||||
module: fl.Chain,
|
||||
target: LoraTarget | list[LoraTarget],
|
||||
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
if isinstance(target, list):
|
||||
for t in target:
|
||||
yield from lora_targets(module, t)
|
||||
return
|
||||
|
||||
if target == LoraTarget.Self:
|
||||
yield from _iter_linears(module)
|
||||
return
|
||||
|
||||
for layer, _ in module.walk(_predicate(target.get_class())):
|
||||
assert isinstance(layer, fl.Chain)
|
||||
yield from _iter_linears(layer)
|
||||
|
||||
|
||||
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
||||
metadata: dict[str, str] | None
|
||||
tensors: dict[str, Tensor]
|
||||
|
||||
class SDLoraManager:
|
||||
def __init__(
|
||||
self,
|
||||
target: StableDiffusion_1,
|
||||
sub_targets: dict[str, list[LoraTarget]],
|
||||
target: LatentDiffusionModel,
|
||||
) -> None:
|
||||
self.target = target
|
||||
|
||||
@property
|
||||
def unet(self) -> fl.Chain:
|
||||
unet = self.target.unet
|
||||
assert isinstance(unet, fl.Chain)
|
||||
return unet
|
||||
|
||||
@property
|
||||
def clip_text_encoder(self) -> fl.Chain:
|
||||
clip_text_encoder = self.target.clip_text_encoder
|
||||
assert isinstance(clip_text_encoder, fl.Chain)
|
||||
return clip_text_encoder
|
||||
|
||||
def load(
|
||||
self,
|
||||
tensors: dict[str, Tensor],
|
||||
/,
|
||||
scale: float = 1.0,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
):
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
) -> None:
|
||||
"""Load the LoRA weights from a dictionary of tensors.
|
||||
|
||||
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
|
||||
|
||||
for model_name in MODELS:
|
||||
if not (model_targets := sub_targets.get(model_name, [])):
|
||||
continue
|
||||
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
|
||||
|
||||
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
|
||||
self.sub_adapters.append(
|
||||
LoraAdapter[type(model)](
|
||||
model,
|
||||
sub_targets=lora_targets(model, model_targets),
|
||||
scale=scale,
|
||||
weights=lora_weights,
|
||||
)
|
||||
Expects the keys to be in the commonly found formats on CivitAI's hub.
|
||||
"""
|
||||
assert len(self.lora_adapters) == 0, "Loras already loaded"
|
||||
loras = Lora.from_dict(
|
||||
{key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
|
||||
)
|
||||
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
|
||||
|
||||
@classmethod
|
||||
def from_safetensors(
|
||||
cls,
|
||||
target: StableDiffusion_1,
|
||||
checkpoint_path: Path | str,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
metadata = load_metadata_from_safetensors(checkpoint_path)
|
||||
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
|
||||
tensors = load_from_safetensors(checkpoint_path, device=target.device)
|
||||
# if no key contains "unet" or "text", assume all keys are for the unet
|
||||
if not "unet" in loras and not "text" in loras:
|
||||
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
|
||||
|
||||
sub_targets: dict[str, list[LoraTarget]] = {}
|
||||
for model_name in MODELS:
|
||||
if not (v := metadata.get(f"{model_name}_targets", "")):
|
||||
continue
|
||||
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
|
||||
self.load_unet(loras)
|
||||
self.load_text_encoder(loras)
|
||||
|
||||
return cls(
|
||||
target,
|
||||
sub_targets,
|
||||
scale=scale,
|
||||
weights=tensors,
|
||||
)
|
||||
self.scale = scale
|
||||
|
||||
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
|
||||
for adapter in self.sub_adapters:
|
||||
adapter.inject()
|
||||
return super().inject(parent)
|
||||
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
|
||||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
|
||||
|
||||
def eject(self) -> None:
|
||||
for adapter in self.sub_adapters:
|
||||
adapter.eject()
|
||||
super().eject()
|
||||
def load_unet(self, loras: dict[str, Lora], /) -> None:
|
||||
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
||||
exclude: list[str] = []
|
||||
exclude = [
|
||||
self.unet_exclusions[exclusion]
|
||||
for exclusion in self.unet_exclusions
|
||||
if all([exclusion not in key for key in unet_loras.keys()])
|
||||
]
|
||||
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
|
||||
|
||||
def unload(self) -> None:
|
||||
for lora_adapter in self.lora_adapters:
|
||||
lora_adapter.eject()
|
||||
|
||||
@property
|
||||
def loras(self) -> list[Lora]:
|
||||
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
|
||||
|
||||
@property
|
||||
def lora_adapters(self) -> list[LoraAdapter]:
|
||||
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
|
||||
|
||||
@property
|
||||
def unet_exclusions(self) -> dict[str, str]:
|
||||
return {
|
||||
"time": "TimestepEncoder",
|
||||
"res": "ResidualBlock",
|
||||
"downsample": "DownsampleBlock",
|
||||
"upsample": "UpsampleBlock",
|
||||
}
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
assert len(self.loras) > 0, "No loras found"
|
||||
assert all([lora.scale == self.loras[0].scale for lora in self.loras])
|
||||
return self.loras[0].scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
for lora in self.loras:
|
||||
lora.scale = value
|
||||
|
||||
@staticmethod
|
||||
def pad(input: str, /, padding_length: int = 2) -> str:
|
||||
new_split: list[str] = []
|
||||
for s in input.split("_"):
|
||||
if s.isdigit():
|
||||
new_split.append(s.zfill(padding_length))
|
||||
else:
|
||||
new_split.append(s)
|
||||
return "_".join(new_split)
|
||||
|
||||
@staticmethod
|
||||
def sort_keys(key: str, /) -> tuple[str, int]:
|
||||
# out0 happens sometimes as an alias for out ; this dict might not be exhaustive
|
||||
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
|
||||
|
||||
for i, s in enumerate(key.split("_")):
|
||||
if s in key_char_order:
|
||||
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
|
||||
return (prefix, key_char_order[s])
|
||||
|
||||
return (SDLoraManager.pad(key), 5)
|
||||
|
||||
@staticmethod
|
||||
def auto_attach(
|
||||
loras: dict[str, Lora],
|
||||
target: fl.Chain,
|
||||
/,
|
||||
exclude: list[str] | None = None,
|
||||
) -> None:
|
||||
failed_loras: dict[str, Lora] = {}
|
||||
for key, lora in loras.items():
|
||||
if attach := lora.auto_attach(target, exclude=exclude):
|
||||
adapter, parent = attach
|
||||
adapter.inject(parent)
|
||||
else:
|
||||
failed_loras[key] = lora
|
||||
|
||||
if failed_loras:
|
||||
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
|
||||
|
||||
# TODO: add a stronger sanity check to make sure loras are attached correctly
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
from torch import allclose, randn
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter, SingleLoraAdapter
|
||||
|
||||
|
||||
def test_single_lora_adapter() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
),
|
||||
fl.Linear(in_features=1, out_features=2),
|
||||
)
|
||||
x = randn(1, 1)
|
||||
y = chain(x)
|
||||
|
||||
lora_adapter = SingleLoraAdapter(chain.Chain.Linear_1).inject(chain.Chain)
|
||||
|
||||
assert isinstance(lora_adapter[1], Lora)
|
||||
assert allclose(input=chain(x), other=y)
|
||||
assert lora_adapter.parent == chain.Chain
|
||||
|
||||
lora_adapter.eject()
|
||||
assert isinstance(chain.Chain[0], fl.Linear)
|
||||
assert len(chain) == 2
|
||||
|
||||
lora_adapter.inject(chain.Chain)
|
||||
assert isinstance(chain.Chain[0], SingleLoraAdapter)
|
||||
|
||||
|
||||
def test_lora_adapter() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
),
|
||||
fl.Linear(in_features=1, out_features=2),
|
||||
)
|
||||
|
||||
# create and inject twice
|
||||
|
||||
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
||||
assert len(list(chain.layers(Lora))) == 3
|
||||
|
||||
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
||||
assert len(list(chain.layers(Lora))) == 6
|
||||
|
||||
# If we init a LoRA when another LoRA is already injected, the Linear
|
||||
# layers of the first LoRA will be adapted too, which is typically not
|
||||
# what we want.
|
||||
# This issue can be avoided either by making the predicate for
|
||||
# `walk` raise StopIteration when it encounters a LoRA (see the SD LoRA)
|
||||
# or by creating all the LoRA Adapters first, before injecting them
|
||||
# (see below).
|
||||
assert len(list(chain.layers(Lora, recurse=True))) == 12
|
||||
|
||||
# ejection in forward order
|
||||
|
||||
a1.eject()
|
||||
assert len(list(chain.layers(Lora))) == 3
|
||||
a2.eject()
|
||||
assert len(list(chain.layers(Lora))) == 0
|
||||
|
||||
# create twice then inject twice
|
||||
|
||||
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
|
||||
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
|
||||
a1.inject()
|
||||
a2.inject()
|
||||
assert len(list(chain.layers(Lora))) == 6
|
||||
|
||||
# If we inject after init we do not have the target selection problem,
|
||||
# the LoRA layers are not adapted.
|
||||
assert len(list(chain.layers(Lora, recurse=True))) == 6
|
||||
|
||||
# ejection in reverse order
|
||||
|
||||
a2.eject()
|
||||
assert len(list(chain.layers(Lora))) == 3
|
||||
a1.eject()
|
||||
assert len(list(chain.layers(Lora))) == 0
|
|
@ -19,7 +19,7 @@ from refiners.foundationals.latent_diffusion import (
|
|||
StableDiffusion_1,
|
||||
StableDiffusion_1_Inpainting,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||
|
@ -182,14 +182,38 @@ def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[
|
|||
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
||||
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
||||
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
|
||||
|
||||
if not weights_path.is_file():
|
||||
warn(f"could not find weights at {weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
return name, condition_image, expected_image, weights_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
|
||||
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
||||
weights_path = test_weights_path / "loras" / "pcuenq_pokemon_lora.safetensors"
|
||||
return expected_image, weights_path
|
||||
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
||||
|
||||
if not weights_path.is_file():
|
||||
warn(f"could not find weights at {weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
tensors = torch.load(weights_path) # type: ignore
|
||||
return expected_image, tensors
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||
expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
|
||||
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
|
||||
|
||||
if not weights_path.is_file():
|
||||
warn(f"could not find weights at {weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
tensors = load_from_safetensors(weights_path)
|
||||
return expected_image, tensors
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -1010,24 +1034,20 @@ def test_diffusion_controlnet_stack(
|
|||
@no_grad()
|
||||
def test_diffusion_lora(
|
||||
sd15_std: StableDiffusion_1,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]],
|
||||
test_device: torch.device,
|
||||
):
|
||||
) -> None:
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
expected_image, lora_weights_path = lora_data_pokemon
|
||||
|
||||
if not lora_weights_path.is_file():
|
||||
warn(f"could not find weights at {lora_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
expected_image, lora_weights = lora_data_pokemon
|
||||
|
||||
prompt = "a cute cat"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
|
||||
SDLoraManager(sd15).load(lora_weights, scale=1)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
@ -1045,77 +1065,45 @@ def test_diffusion_lora(
|
|||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_lora_float16(
|
||||
sd15_std_float16: StableDiffusion_1,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std_float16
|
||||
n_steps = 30
|
||||
def test_diffusion_sdxl_lora(
|
||||
sdxl_ddim: StableDiffusion_XL,
|
||||
lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
sdxl = sdxl_ddim
|
||||
expected_image, lora_weights = lora_data_dpo
|
||||
|
||||
expected_image, lora_weights_path = lora_data_pokemon
|
||||
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
|
||||
# except that we are using DDIM instead of sde-dpmsolver++
|
||||
n_steps = 40
|
||||
seed = 12341234123
|
||||
guidance_scale = 7.5
|
||||
lora_scale = 1.4
|
||||
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
|
||||
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
|
||||
|
||||
if not lora_weights_path.is_file():
|
||||
warn(f"could not find weights at {lora_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
SDLoraManager(sdxl).load(lora_weights, scale=lora_scale)
|
||||
|
||||
prompt = "a cute cat"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt, negative_text=negative_prompt
|
||||
)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
|
||||
manual_seed(seed=seed)
|
||||
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
condition_scale=guidance_scale,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_lora_twice(
|
||||
sd15_std: StableDiffusion_1,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
expected_image, lora_weights_path = lora_data_pokemon
|
||||
|
||||
if not lora_weights_path.is_file():
|
||||
warn(f"could not find weights at {lora_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
prompt = "a cute cat"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
# The same LoRA is used twice which is not a common use case: this is purely for testing purpose
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject()
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject()
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ Special cases:
|
|||
- `expected_restart.png`
|
||||
- `expected_freeu.png`
|
||||
- `expected_dropy_slime_9752.png`
|
||||
- `expected_sdxl_dpo_lora.png`
|
||||
|
||||
## Other images
|
||||
|
||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_sdxl_dpo_lora.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_sdxl_dpo_lora.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.6 MiB |
Loading…
Reference in a new issue