mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
refactor Lora LoraAdapter and the latent_diffusion/lora file
This commit is contained in:
parent
dd87b9706e
commit
a1f50f3f9d
10
README.md
10
README.md
|
@ -158,16 +158,6 @@ You should get:
|
||||||
|
|
||||||
![dropy slime output](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_slime_9752.png)
|
![dropy slime output](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_slime_9752.png)
|
||||||
|
|
||||||
### Training
|
|
||||||
|
|
||||||
Refiners has a built-in training utils library and provides scripts that can be used as a starting point.
|
|
||||||
|
|
||||||
E.g. to train a LoRA on top of Stable Diffusion, copy and edit `configs/finetune-lora.toml` to suit your needs and launch the training as follows:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml
|
|
||||||
```
|
|
||||||
|
|
||||||
## Adapter Zoo
|
## Adapter Zoo
|
||||||
|
|
||||||
For now, given [finegrain](https://finegrain.ai)'s mission, we are focusing on image edition tasks. We support:
|
For now, given [finegrain](https://finegrain.ai)'s mission, we are focusing on image edition tasks. We support:
|
||||||
|
|
|
@ -1,68 +0,0 @@
|
||||||
[wandb]
|
|
||||||
mode = "offline" # "online", "offline", "disabled"
|
|
||||||
entity = "acme"
|
|
||||||
project = "test-lora-training"
|
|
||||||
|
|
||||||
[models]
|
|
||||||
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
|
|
||||||
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
|
|
||||||
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
|
|
||||||
|
|
||||||
[latent_diffusion]
|
|
||||||
unconditional_sampling_probability = 0.05
|
|
||||||
offset_noise = 0.1
|
|
||||||
|
|
||||||
[lora]
|
|
||||||
rank = 16
|
|
||||||
trigger_phrase = "a spsh photo,"
|
|
||||||
use_only_trigger_probability = 1.0
|
|
||||||
unet_targets = ["CrossAttentionBlock2d"]
|
|
||||||
text_encoder_targets = ["TransformerLayer"]
|
|
||||||
lda_targets = []
|
|
||||||
|
|
||||||
[training]
|
|
||||||
duration = "1000:epoch"
|
|
||||||
seed = 0
|
|
||||||
gpu_index = 0
|
|
||||||
batch_size = 4
|
|
||||||
gradient_accumulation = "4:step"
|
|
||||||
clip_grad_norm = 1.0
|
|
||||||
# clip_grad_value = 1.0
|
|
||||||
evaluation_interval = "5:epoch"
|
|
||||||
evaluation_seed = 1
|
|
||||||
|
|
||||||
|
|
||||||
[optimizer]
|
|
||||||
optimizer = "Prodigy" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
|
|
||||||
learning_rate = 1
|
|
||||||
betas = [0.9, 0.999]
|
|
||||||
eps = 1e-8
|
|
||||||
weight_decay = 1e-2
|
|
||||||
|
|
||||||
[scheduler]
|
|
||||||
scheduler_type = "ConstantLR"
|
|
||||||
update_interval = "1:step"
|
|
||||||
warmup = "500:step"
|
|
||||||
|
|
||||||
|
|
||||||
[dropout]
|
|
||||||
dropout_probability = 0.2
|
|
||||||
use_gyro_dropout = false
|
|
||||||
|
|
||||||
[dataset]
|
|
||||||
hf_repo = "acme/images"
|
|
||||||
revision = "main"
|
|
||||||
|
|
||||||
[checkpointing]
|
|
||||||
# save_folder = "/path/to/ckpts"
|
|
||||||
save_interval = "1:step"
|
|
||||||
|
|
||||||
[test_diffusion]
|
|
||||||
num_inference_steps = 30
|
|
||||||
use_short_prompts = false
|
|
||||||
prompts = [
|
|
||||||
"a cute cat",
|
|
||||||
"a cute dog",
|
|
||||||
"a cute bird",
|
|
||||||
"a cute horse",
|
|
||||||
]
|
|
|
@ -1,149 +0,0 @@
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.nn import Parameter as TorchParameter
|
|
||||||
from torch.nn.init import zeros_
|
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
|
||||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
|
||||||
from refiners.fluxion.utils import no_grad, save_to_safetensors
|
|
||||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
|
|
||||||
|
|
||||||
|
|
||||||
def get_weight(linear: fl.Linear) -> torch.Tensor:
|
|
||||||
assert linear.bias is None
|
|
||||||
return linear.state_dict()["weight"]
|
|
||||||
|
|
||||||
|
|
||||||
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
|
|
||||||
weights: list[torch.Tensor] = []
|
|
||||||
for lora in module.layers(layer_type=Lora):
|
|
||||||
linears = list(lora.layers(layer_type=fl.Linear))
|
|
||||||
assert len(linears) == 2
|
|
||||||
weights.extend((get_weight(linear=linears[1]), get_weight(linear=linears[0]))) # aka (up_weight, down_weight)
|
|
||||||
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
|
|
||||||
|
|
||||||
|
|
||||||
class Args(argparse.Namespace):
|
|
||||||
source_path: str
|
|
||||||
base_model: str
|
|
||||||
output_file: str
|
|
||||||
verbose: bool
|
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def process(args: Args) -> None:
|
|
||||||
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
|
|
||||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
|
||||||
diffusers_sd = DiffusionPipeline.from_pretrained( # type: ignore
|
|
||||||
pretrained_model_name_or_path=args.base_model,
|
|
||||||
low_cpu_mem_usage=False,
|
|
||||||
)
|
|
||||||
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
|
||||||
|
|
||||||
refiners_model = SD1UNet(in_channels=4)
|
|
||||||
target = LoraTarget.CrossAttention
|
|
||||||
metadata = {"unet_targets": "CrossAttentionBlock2d"}
|
|
||||||
rank = diffusers_state_dict[
|
|
||||||
"mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
|
|
||||||
].shape[0]
|
|
||||||
|
|
||||||
x = torch.randn(1, 4, 32, 32)
|
|
||||||
timestep = torch.tensor(data=[0])
|
|
||||||
clip_text_embeddings = torch.randn(1, 77, 768)
|
|
||||||
|
|
||||||
refiners_model.set_timestep(timestep=timestep)
|
|
||||||
refiners_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
|
||||||
refiners_args = (x,)
|
|
||||||
|
|
||||||
diffusers_args = (x, timestep, clip_text_embeddings)
|
|
||||||
|
|
||||||
converter = ModelConverter(
|
|
||||||
source_model=refiners_model, target_model=diffusers_model, skip_output_check=True, verbose=args.verbose
|
|
||||||
)
|
|
||||||
if not converter.run(
|
|
||||||
source_args=refiners_args,
|
|
||||||
target_args=diffusers_args,
|
|
||||||
):
|
|
||||||
raise RuntimeError("Model conversion failed")
|
|
||||||
|
|
||||||
diffusers_to_refiners = converter.get_mapping()
|
|
||||||
|
|
||||||
LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
|
|
||||||
|
|
||||||
for layer in refiners_model.layers(layer_type=Lora):
|
|
||||||
zeros_(tensor=layer.Linear_1.weight)
|
|
||||||
|
|
||||||
targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
|
|
||||||
for target_k in targets:
|
|
||||||
k_p, k_s = target_k.split(".processor.")
|
|
||||||
r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
|
|
||||||
assert len(r) == 1
|
|
||||||
orig_k = r[0]
|
|
||||||
orig_path = orig_k.split(sep=".")
|
|
||||||
p = refiners_model
|
|
||||||
for seg in orig_path[:-1]:
|
|
||||||
p = p[seg]
|
|
||||||
assert isinstance(p, fl.Chain)
|
|
||||||
last_seg = (
|
|
||||||
"SingleLoraAdapter"
|
|
||||||
if orig_path[-1] == "Linear"
|
|
||||||
else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
|
|
||||||
)
|
|
||||||
p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
|
|
||||||
p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
|
|
||||||
p[last_seg].Lora.load_weights(p_down, p_up)
|
|
||||||
|
|
||||||
state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
|
|
||||||
assert len(state_dict) == 320
|
|
||||||
save_to_safetensors(path=args.output_path, tensors=state_dict, metadata=metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description="Convert LoRAs saved using the diffusers library to refiners format.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--from",
|
|
||||||
type=str,
|
|
||||||
dest="source_path",
|
|
||||||
required=True,
|
|
||||||
help="Source file path (.bin|safetensors) containing the LoRAs.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--base-model",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="runwayml/stable-diffusion-v1-5",
|
|
||||||
help="Base model, used for the UNet structure. Default: runwayml/stable-diffusion-v1-5",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--to",
|
|
||||||
type=str,
|
|
||||||
dest="output_path",
|
|
||||||
required=False,
|
|
||||||
default=None,
|
|
||||||
help=(
|
|
||||||
"Output file path (.safetensors) for converted LoRAs. If not provided, the output path will be the same as"
|
|
||||||
" the source path."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--verbose",
|
|
||||||
action="store_true",
|
|
||||||
dest="verbose",
|
|
||||||
default=False,
|
|
||||||
help="Use this flag to print verbose output during conversion.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args(namespace=Args())
|
|
||||||
if args.output_path is None:
|
|
||||||
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
|
|
||||||
process(args=args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,124 +0,0 @@
|
||||||
import argparse
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from convert_diffusers_unet import Args as UnetConversionArgs, setup_converter as convert_unet
|
|
||||||
from convert_transformers_clip_text_model import (
|
|
||||||
Args as TextEncoderConversionArgs,
|
|
||||||
setup_converter as convert_text_encoder,
|
|
||||||
)
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
|
||||||
from refiners.fluxion.utils import (
|
|
||||||
load_from_safetensors,
|
|
||||||
load_metadata_from_safetensors,
|
|
||||||
save_to_safetensors,
|
|
||||||
)
|
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
|
||||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
|
||||||
|
|
||||||
|
|
||||||
def get_unet_mapping(source_path: str) -> dict[str, str]:
|
|
||||||
args = UnetConversionArgs(source_path=source_path, verbose=False)
|
|
||||||
return convert_unet(args=args).get_mapping()
|
|
||||||
|
|
||||||
|
|
||||||
def get_text_encoder_mapping(source_path: str) -> dict[str, str]:
|
|
||||||
args = TextEncoderConversionArgs(source_path=source_path, subfolder="text_encoder", verbose=False)
|
|
||||||
return convert_text_encoder(
|
|
||||||
args=args,
|
|
||||||
).get_mapping()
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description="Converts a refiner's LoRA weights to SD-WebUI's LoRA weights")
|
|
||||||
parser.add_argument(
|
|
||||||
"-i",
|
|
||||||
"--input-file",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the input file with refiner's LoRA weights (safetensors format)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-o",
|
|
||||||
"--output-file",
|
|
||||||
type=str,
|
|
||||||
default="sdwebui_loras.safetensors",
|
|
||||||
help="Path to the output file with sd-webui's LoRA weights (safetensors format)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--sd15",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="runwayml/stable-diffusion-v1-5",
|
|
||||||
help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
metadata = load_metadata_from_safetensors(path=args.input_file)
|
|
||||||
assert metadata is not None, f"Could not load metadata from {args.input_file}"
|
|
||||||
tensors = load_from_safetensors(path=args.input_file)
|
|
||||||
|
|
||||||
state_dict: dict[str, Tensor] = {}
|
|
||||||
|
|
||||||
for meta_key, meta_value in metadata.items():
|
|
||||||
match meta_key:
|
|
||||||
case "unet_targets":
|
|
||||||
model = SD1UNet(in_channels=4)
|
|
||||||
create_mapping = partial(get_unet_mapping, source_path=args.sd15)
|
|
||||||
key_prefix = "unet."
|
|
||||||
lora_prefix = "lora_unet_"
|
|
||||||
case "text_encoder_targets":
|
|
||||||
model = CLIPTextEncoderL()
|
|
||||||
create_mapping = partial(get_text_encoder_mapping, source_path=args.sd15)
|
|
||||||
key_prefix = "text_encoder."
|
|
||||||
lora_prefix = "lora_te_"
|
|
||||||
case "lda_targets":
|
|
||||||
raise ValueError("SD-WebUI does not support LoRA for the auto-encoder")
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
|
|
||||||
|
|
||||||
submodule_to_key: dict[fl.Module, str] = {}
|
|
||||||
for name, submodule in model.named_modules():
|
|
||||||
submodule_to_key[submodule] = name
|
|
||||||
|
|
||||||
# SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.:
|
|
||||||
#
|
|
||||||
# lora_unet_down_blocks_0_attentions_0_proj_in.alpha
|
|
||||||
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight
|
|
||||||
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight
|
|
||||||
# ...
|
|
||||||
#
|
|
||||||
# Internally SD-WebUI has some logic[1] to convert such keys into the CompVis format. See
|
|
||||||
# `convert_diffusers_name_to_compvis` for more details.
|
|
||||||
#
|
|
||||||
# [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225
|
|
||||||
|
|
||||||
refiners_to_diffusers = create_mapping()
|
|
||||||
assert refiners_to_diffusers is not None
|
|
||||||
|
|
||||||
# Compute the corresponding diffusers' keys where LoRA layers must be applied
|
|
||||||
lora_injection_points: list[str] = [
|
|
||||||
refiners_to_diffusers[submodule_to_key[linear]]
|
|
||||||
for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
|
|
||||||
for layer in model.layers(layer_type=target.get_class())
|
|
||||||
for linear in layer.layers(layer_type=fl.Linear)
|
|
||||||
]
|
|
||||||
|
|
||||||
lora_weights = [w for w in [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]]
|
|
||||||
assert len(lora_injection_points) == len(lora_weights) // 2
|
|
||||||
|
|
||||||
# Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores)
|
|
||||||
for i, diffusers_key in enumerate(iterable=lora_injection_points):
|
|
||||||
lora_key = lora_prefix + diffusers_key.replace(".", "_")
|
|
||||||
# Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). Refiners uses a scale = 1.0
|
|
||||||
# by default (see `lora_calc_updown` in SD-WebUI for more details)
|
|
||||||
state_dict[lora_key + ".lora_up.weight"] = lora_weights[2 * i]
|
|
||||||
state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
|
|
||||||
|
|
||||||
assert state_dict
|
|
||||||
save_to_safetensors(path=args.output_file, tensors=state_dict)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -229,10 +229,15 @@ def download_vae_ft_mse():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def download_lora():
|
def download_loras():
|
||||||
dest_folder = os.path.join(test_weights_dir, "pcuenq", "pokemon-lora")
|
dest_folder = os.path.join(test_weights_dir, "loras", "pokemon-lora")
|
||||||
download_file("https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin", dest_folder)
|
download_file("https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin", dest_folder)
|
||||||
|
|
||||||
|
dest_folder = os.path.join(test_weights_dir, "loras", "dpo-lora")
|
||||||
|
download_file(
|
||||||
|
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def download_preprocessors():
|
def download_preprocessors():
|
||||||
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
|
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
|
||||||
|
@ -454,17 +459,6 @@ def convert_vae_fp16_fix():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_lora():
|
|
||||||
os.makedirs("tests/weights/loras", exist_ok=True)
|
|
||||||
run_conversion_script(
|
|
||||||
"convert_diffusers_lora.py",
|
|
||||||
"tests/weights/pcuenq/pokemon-lora/pytorch_lora_weights.bin",
|
|
||||||
"tests/weights/loras/pcuenq_pokemon_lora.safetensors",
|
|
||||||
additional_args=["--base-model", "tests/weights/runwayml/stable-diffusion-v1-5"],
|
|
||||||
expected_hash="a9d7e08e",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_preprocessors():
|
def convert_preprocessors():
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[
|
[
|
||||||
|
@ -632,7 +626,7 @@ def download_all():
|
||||||
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
|
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
|
||||||
download_vae_ft_mse()
|
download_vae_ft_mse()
|
||||||
download_vae_fp16_fix()
|
download_vae_fp16_fix()
|
||||||
download_lora()
|
download_loras()
|
||||||
download_preprocessors()
|
download_preprocessors()
|
||||||
download_controlnet()
|
download_controlnet()
|
||||||
download_unclip()
|
download_unclip()
|
||||||
|
@ -647,7 +641,7 @@ def convert_all():
|
||||||
convert_sdxl()
|
convert_sdxl()
|
||||||
convert_vae_ft_mse()
|
convert_vae_ft_mse()
|
||||||
convert_vae_fp16_fix()
|
convert_vae_fp16_fix()
|
||||||
convert_lora()
|
# Note: no convert loras: this is done at runtime by `SDLoraManager`
|
||||||
convert_preprocessors()
|
convert_preprocessors()
|
||||||
convert_controlnet()
|
convert_controlnet()
|
||||||
convert_unclip()
|
convert_unclip()
|
||||||
|
|
|
@ -1,123 +0,0 @@
|
||||||
import random
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
|
||||||
from refiners.fluxion.utils import save_to_safetensors
|
|
||||||
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
|
|
||||||
from refiners.training_utils.callback import Callback
|
|
||||||
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
|
|
||||||
from refiners.training_utils.latent_diffusion import (
|
|
||||||
FinetuneLatentDiffusionConfig,
|
|
||||||
LatentDiffusionConfig,
|
|
||||||
LatentDiffusionTrainer,
|
|
||||||
TextEmbeddingLatentsBatch,
|
|
||||||
TextEmbeddingLatentsDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoraConfig(BaseModel):
|
|
||||||
rank: int = 32
|
|
||||||
trigger_phrase: str = ""
|
|
||||||
use_only_trigger_probability: float = 0.0
|
|
||||||
unet_targets: list[LoraTarget]
|
|
||||||
text_encoder_targets: list[LoraTarget]
|
|
||||||
lda_targets: list[LoraTarget]
|
|
||||||
|
|
||||||
|
|
||||||
class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
trainer: "LoraLatentDiffusionTrainer",
|
|
||||||
) -> None:
|
|
||||||
super().__init__(trainer=trainer)
|
|
||||||
self.trigger_phrase = trainer.config.lora.trigger_phrase
|
|
||||||
self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
|
|
||||||
logger.info(f"Trigger phrase: {self.trigger_phrase}")
|
|
||||||
|
|
||||||
def process_caption(self, caption: str) -> str:
|
|
||||||
caption = super().process_caption(caption=caption)
|
|
||||||
if self.trigger_phrase:
|
|
||||||
caption = (
|
|
||||||
f"{self.trigger_phrase} {caption}"
|
|
||||||
if random.random() < self.use_only_trigger_probability
|
|
||||||
else self.trigger_phrase
|
|
||||||
)
|
|
||||||
return caption
|
|
||||||
|
|
||||||
|
|
||||||
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
|
||||||
dataset: HuggingfaceDatasetConfig
|
|
||||||
latent_diffusion: LatentDiffusionConfig
|
|
||||||
lora: LoraConfig
|
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
|
||||||
"""Pydantic v2 does post init differently, so we need to override this method too."""
|
|
||||||
logger.info("Freezing models to train only the loras.")
|
|
||||||
self.models["unet"].train = False
|
|
||||||
self.models["text_encoder"].train = False
|
|
||||||
self.models["lda"].train = False
|
|
||||||
|
|
||||||
|
|
||||||
class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: LoraLatentDiffusionConfig,
|
|
||||||
callbacks: "list[Callback[Any]] | None" = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(config=config, callbacks=callbacks)
|
|
||||||
self.callbacks.extend((LoadLoras(), SaveLoras()))
|
|
||||||
|
|
||||||
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
|
|
||||||
return TriggerPhraseDataset(trainer=self)
|
|
||||||
|
|
||||||
|
|
||||||
class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
|
|
||||||
def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
|
|
||||||
lora_config = trainer.config.lora
|
|
||||||
|
|
||||||
for model_name in MODELS:
|
|
||||||
model = getattr(trainer, model_name)
|
|
||||||
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
|
|
||||||
adapter = LoraAdapter[type(model)](
|
|
||||||
model,
|
|
||||||
sub_targets=lora_targets(model, model_targets),
|
|
||||||
rank=lora_config.rank,
|
|
||||||
)
|
|
||||||
for sub_adapter, _ in adapter.sub_adapters:
|
|
||||||
for linear in sub_adapter.Lora.layers(fl.Linear):
|
|
||||||
linear.requires_grad_(requires_grad=True)
|
|
||||||
adapter.inject()
|
|
||||||
|
|
||||||
|
|
||||||
class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
|
|
||||||
def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
|
|
||||||
tensors: dict[str, Tensor] = {}
|
|
||||||
metadata: dict[str, str] = {}
|
|
||||||
|
|
||||||
for model_name in MODELS:
|
|
||||||
model = getattr(trainer, model_name)
|
|
||||||
adapter = model.parent
|
|
||||||
tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
|
|
||||||
metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}
|
|
||||||
|
|
||||||
save_to_safetensors(
|
|
||||||
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
|
|
||||||
tensors=tensors,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import sys
|
|
||||||
|
|
||||||
config_path = sys.argv[1]
|
|
||||||
config = LoraLatentDiffusionConfig.load_from_toml(
|
|
||||||
toml_path=config_path,
|
|
||||||
)
|
|
||||||
trainer = LoraLatentDiffusionTrainer(config=config)
|
|
||||||
trainer.train()
|
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Generic, Iterable, TypeVar
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from torch.nn import Parameter as TorchParameter
|
from torch.nn import Parameter as TorchParameter
|
||||||
|
@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
from refiners.fluxion.layers.chain import Chain
|
||||||
T = TypeVar("T", bound=fl.Chain)
|
|
||||||
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
|
|
||||||
|
|
||||||
|
|
||||||
class Lora(fl.Chain):
|
class Lora(fl.Chain, ABC):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
rank: int = 16,
|
||||||
|
scale: float = 1.0,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.rank = rank
|
||||||
|
self._scale = scale
|
||||||
|
|
||||||
|
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
|
||||||
|
|
||||||
|
normal_(tensor=self.down.weight, std=1 / self.rank)
|
||||||
|
zeros_(tensor=self.up.weight)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def lora_layers(
|
||||||
|
self, device: Device | str | None = None, dtype: DType | None = None
|
||||||
|
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def down(self) -> fl.WeightedModule:
|
||||||
|
down_layer = self[0]
|
||||||
|
assert isinstance(down_layer, fl.WeightedModule)
|
||||||
|
return down_layer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def up(self) -> fl.WeightedModule:
|
||||||
|
up_layer = self[1]
|
||||||
|
assert isinstance(up_layer, fl.WeightedModule)
|
||||||
|
return up_layer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self) -> float:
|
||||||
|
return self._scale
|
||||||
|
|
||||||
|
@scale.setter
|
||||||
|
def scale(self, value: float) -> None:
|
||||||
|
self._scale = value
|
||||||
|
self.ensure_find(fl.Multiply).scale = value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_weights(
|
||||||
|
cls,
|
||||||
|
down: Tensor,
|
||||||
|
up: Tensor,
|
||||||
|
) -> "Lora":
|
||||||
|
match (up.ndim, down.ndim):
|
||||||
|
case (2, 2):
|
||||||
|
return LinearLora.from_weights(up=up, down=down)
|
||||||
|
case (4, 4):
|
||||||
|
return Conv2dLora.from_weights(up=up, down=down)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
|
||||||
|
"""
|
||||||
|
Create a dictionary of LoRA layers from a state dict.
|
||||||
|
|
||||||
|
Expects the state dict to be a succession of down and up weights.
|
||||||
|
"""
|
||||||
|
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
|
||||||
|
loras: dict[str, Lora] = {}
|
||||||
|
for down_key, down_tensor, up_tensor in zip(
|
||||||
|
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
|
||||||
|
):
|
||||||
|
key = ".".join(down_key.split(".")[:-2])
|
||||||
|
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
|
||||||
|
return loras
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||||
|
assert down_weight.shape == self.down.weight.shape
|
||||||
|
assert up_weight.shape == self.up.weight.shape
|
||||||
|
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
|
||||||
|
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
class LinearLora(Lora):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
rank: int = 16,
|
rank: int = 16,
|
||||||
|
scale: float = 1.0,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.rank = rank
|
|
||||||
self.scale: float = 1.0
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
|
||||||
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
|
|
||||||
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
|
@classmethod
|
||||||
fl.Lambda(func=self.scale_outputs),
|
def from_weights(
|
||||||
|
cls,
|
||||||
|
down: Tensor,
|
||||||
|
up: Tensor,
|
||||||
|
) -> "LinearLora":
|
||||||
|
assert up.ndim == 2 and down.ndim == 2
|
||||||
|
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
|
||||||
|
lora = cls(
|
||||||
|
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
|
||||||
)
|
)
|
||||||
|
lora.load_weights(down_weight=down, up_weight=up)
|
||||||
|
return lora
|
||||||
|
|
||||||
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
|
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
|
||||||
zeros_(tensor=self.Linear_2.weight)
|
for layer, parent in target.walk(fl.Linear):
|
||||||
|
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
|
||||||
|
continue
|
||||||
|
|
||||||
def scale_outputs(self, x: Tensor) -> Tensor:
|
if exclude is not None and any(
|
||||||
return x * self.scale
|
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
def set_scale(self, scale: float) -> None:
|
if layer.in_features == self.in_features and layer.out_features == self.out_features:
|
||||||
self.scale = scale
|
return LoraAdapter(target=layer, lora=self), parent
|
||||||
|
|
||||||
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
def lora_layers(
|
||||||
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
|
self, device: Device | str | None = None, dtype: DType | None = None
|
||||||
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
|
) -> tuple[fl.Linear, fl.Linear]:
|
||||||
|
return (
|
||||||
@property
|
fl.Linear(
|
||||||
def up_weight(self) -> Tensor:
|
in_features=self.in_features,
|
||||||
return self.Linear_2.weight.data
|
out_features=self.rank,
|
||||||
|
bias=False,
|
||||||
@property
|
device=device,
|
||||||
def down_weight(self) -> Tensor:
|
dtype=dtype,
|
||||||
return self.Linear_1.weight.data
|
),
|
||||||
|
fl.Linear(
|
||||||
|
in_features=self.rank,
|
||||||
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
|
out_features=self.out_features,
|
||||||
def __init__(
|
bias=False,
|
||||||
self,
|
device=device,
|
||||||
target: fl.Linear,
|
dtype=dtype,
|
||||||
rank: int = 16,
|
|
||||||
scale: float = 1.0,
|
|
||||||
) -> None:
|
|
||||||
self.in_features = target.in_features
|
|
||||||
self.out_features = target.out_features
|
|
||||||
self.rank = rank
|
|
||||||
self.scale = scale
|
|
||||||
with self.setup_adapter(target):
|
|
||||||
super().__init__(
|
|
||||||
target,
|
|
||||||
Lora(
|
|
||||||
in_features=target.in_features,
|
|
||||||
out_features=target.out_features,
|
|
||||||
rank=rank,
|
|
||||||
device=target.device,
|
|
||||||
dtype=target.dtype,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.Lora.set_scale(scale=scale)
|
|
||||||
|
|
||||||
|
|
||||||
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
|
class Conv2dLora(Lora):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: T,
|
in_channels: int,
|
||||||
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
|
out_channels: int,
|
||||||
rank: int | None = None,
|
rank: int = 16,
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
weights: list[Tensor] | None = None,
|
kernel_size: tuple[int, int] = (1, 3),
|
||||||
|
stride: tuple[int, int] = (1, 1),
|
||||||
|
padding: tuple[int, int] = (0, 1),
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_weights(
|
||||||
|
cls,
|
||||||
|
down: Tensor,
|
||||||
|
up: Tensor,
|
||||||
|
) -> "Conv2dLora":
|
||||||
|
assert up.ndim == 4 and down.ndim == 4
|
||||||
|
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
|
||||||
|
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
|
||||||
|
down_padding = 1 if down_kernel_size == 3 else 0
|
||||||
|
up_padding = 1 if up_kernel_size == 3 else 0
|
||||||
|
lora = cls(
|
||||||
|
in_channels=down.shape[1],
|
||||||
|
out_channels=up.shape[0],
|
||||||
|
rank=down.shape[0],
|
||||||
|
kernel_size=(down_kernel_size, up_kernel_size),
|
||||||
|
padding=(down_padding, up_padding),
|
||||||
|
device=up.device,
|
||||||
|
dtype=up.dtype,
|
||||||
|
)
|
||||||
|
lora.load_weights(down_weight=down, up_weight=up)
|
||||||
|
return lora
|
||||||
|
|
||||||
|
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
|
||||||
|
for layer, parent in target.walk(fl.Conv2d):
|
||||||
|
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if exclude is not None and any(
|
||||||
|
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
|
||||||
|
if layer.stride != (self.stride[0], self.stride[0]):
|
||||||
|
self.down.stride = layer.stride
|
||||||
|
|
||||||
|
return LoraAdapter(
|
||||||
|
target=layer,
|
||||||
|
lora=self,
|
||||||
|
), parent
|
||||||
|
|
||||||
|
def lora_layers(
|
||||||
|
self, device: Device | str | None = None, dtype: DType | None = None
|
||||||
|
) -> tuple[fl.Conv2d, fl.Conv2d]:
|
||||||
|
return (
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=self.in_channels,
|
||||||
|
out_channels=self.rank,
|
||||||
|
kernel_size=self.kernel_size[0],
|
||||||
|
stride=self.stride[0],
|
||||||
|
padding=self.padding[0],
|
||||||
|
use_bias=False,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=self.rank,
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=self.kernel_size[1],
|
||||||
|
stride=self.stride[1],
|
||||||
|
padding=self.padding[1],
|
||||||
|
use_bias=False,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
||||||
|
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(target)
|
super().__init__(target, lora)
|
||||||
|
|
||||||
if weights is not None:
|
|
||||||
assert len(weights) % 2 == 0
|
|
||||||
weights_rank = weights[0].shape[1]
|
|
||||||
if rank is None:
|
|
||||||
rank = weights_rank
|
|
||||||
else:
|
|
||||||
assert rank == weights_rank
|
|
||||||
|
|
||||||
assert rank is not None, "either pass a rank or weights"
|
|
||||||
|
|
||||||
self.sub_targets = sub_targets
|
|
||||||
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
|
|
||||||
|
|
||||||
for linear, parent in self.sub_targets:
|
|
||||||
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
|
|
||||||
|
|
||||||
if weights is not None:
|
|
||||||
assert len(self.sub_adapters) == (len(weights) // 2)
|
|
||||||
for i, (adapter, _) in enumerate(self.sub_adapters):
|
|
||||||
lora = adapter.Lora
|
|
||||||
assert (
|
|
||||||
lora.rank == weights[i * 2].shape[1]
|
|
||||||
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
|
|
||||||
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
|
|
||||||
|
|
||||||
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
|
|
||||||
for adapter, adapter_parent in self.sub_adapters:
|
|
||||||
adapter.inject(adapter_parent)
|
|
||||||
return super().inject(parent)
|
|
||||||
|
|
||||||
def eject(self) -> None:
|
|
||||||
for adapter, _ in self.sub_adapters:
|
|
||||||
adapter.eject()
|
|
||||||
super().eject()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def weights(self) -> list[Tensor]:
|
def lora(self) -> Lora:
|
||||||
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
|
return self.ensure_find(Lora)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self) -> float:
|
||||||
|
return self.lora.scale
|
||||||
|
|
||||||
|
@scale.setter
|
||||||
|
def scale(self, value: float) -> None:
|
||||||
|
self.lora.scale = value
|
||||||
|
|
|
@ -1,146 +1,140 @@
|
||||||
from enum import Enum
|
from warnings import warn
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, Iterator
|
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
|
||||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
||||||
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||||
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
|
||||||
from refiners.foundationals.latent_diffusion import (
|
|
||||||
CLIPTextEncoderL,
|
|
||||||
LatentDiffusionAutoencoder,
|
|
||||||
SD1UNet,
|
|
||||||
StableDiffusion_1,
|
|
||||||
)
|
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
|
||||||
|
|
||||||
MODELS = ["unet", "text_encoder", "lda"]
|
|
||||||
|
|
||||||
|
|
||||||
class LoraTarget(str, Enum):
|
class SDLoraManager:
|
||||||
Self = "self"
|
|
||||||
Attention = "Attention"
|
|
||||||
SelfAttention = "SelfAttention"
|
|
||||||
CrossAttention = "CrossAttentionBlock2d"
|
|
||||||
FeedForward = "FeedForward"
|
|
||||||
TransformerLayer = "TransformerLayer"
|
|
||||||
|
|
||||||
def get_class(self) -> type[fl.Chain]:
|
|
||||||
match self:
|
|
||||||
case LoraTarget.Self:
|
|
||||||
return fl.Chain
|
|
||||||
case LoraTarget.Attention:
|
|
||||||
return fl.Attention
|
|
||||||
case LoraTarget.SelfAttention:
|
|
||||||
return fl.SelfAttention
|
|
||||||
case LoraTarget.CrossAttention:
|
|
||||||
return CrossAttentionBlock2d
|
|
||||||
case LoraTarget.FeedForward:
|
|
||||||
return FeedForward
|
|
||||||
case LoraTarget.TransformerLayer:
|
|
||||||
return TransformerLayer
|
|
||||||
|
|
||||||
|
|
||||||
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
|
||||||
def f(m: fl.Module, _: fl.Chain) -> bool:
|
|
||||||
if isinstance(m, Lora): # do not adapt other LoRAs
|
|
||||||
raise StopIteration
|
|
||||||
if isinstance(m, Controlnet): # do not adapt Controlnet linears
|
|
||||||
raise StopIteration
|
|
||||||
return isinstance(m, k)
|
|
||||||
|
|
||||||
return f
|
|
||||||
|
|
||||||
|
|
||||||
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
|
||||||
for m, p in module.walk(_predicate(fl.Linear)):
|
|
||||||
assert isinstance(m, fl.Linear)
|
|
||||||
yield (m, p)
|
|
||||||
|
|
||||||
|
|
||||||
def lora_targets(
|
|
||||||
module: fl.Chain,
|
|
||||||
target: LoraTarget | list[LoraTarget],
|
|
||||||
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
|
||||||
if isinstance(target, list):
|
|
||||||
for t in target:
|
|
||||||
yield from lora_targets(module, t)
|
|
||||||
return
|
|
||||||
|
|
||||||
if target == LoraTarget.Self:
|
|
||||||
yield from _iter_linears(module)
|
|
||||||
return
|
|
||||||
|
|
||||||
for layer, _ in module.walk(_predicate(target.get_class())):
|
|
||||||
assert isinstance(layer, fl.Chain)
|
|
||||||
yield from _iter_linears(layer)
|
|
||||||
|
|
||||||
|
|
||||||
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
|
||||||
metadata: dict[str, str] | None
|
|
||||||
tensors: dict[str, Tensor]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: StableDiffusion_1,
|
target: LatentDiffusionModel,
|
||||||
sub_targets: dict[str, list[LoraTarget]],
|
) -> None:
|
||||||
|
self.target = target
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unet(self) -> fl.Chain:
|
||||||
|
unet = self.target.unet
|
||||||
|
assert isinstance(unet, fl.Chain)
|
||||||
|
return unet
|
||||||
|
|
||||||
|
@property
|
||||||
|
def clip_text_encoder(self) -> fl.Chain:
|
||||||
|
clip_text_encoder = self.target.clip_text_encoder
|
||||||
|
assert isinstance(clip_text_encoder, fl.Chain)
|
||||||
|
return clip_text_encoder
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
tensors: dict[str, Tensor],
|
||||||
|
/,
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
weights: dict[str, Tensor] | None = None,
|
) -> None:
|
||||||
):
|
"""Load the LoRA weights from a dictionary of tensors.
|
||||||
with self.setup_adapter(target):
|
|
||||||
super().__init__(target)
|
|
||||||
|
|
||||||
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
|
Expects the keys to be in the commonly found formats on CivitAI's hub.
|
||||||
|
"""
|
||||||
for model_name in MODELS:
|
assert len(self.lora_adapters) == 0, "Loras already loaded"
|
||||||
if not (model_targets := sub_targets.get(model_name, [])):
|
loras = Lora.from_dict(
|
||||||
continue
|
{key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
|
||||||
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
|
|
||||||
|
|
||||||
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
|
|
||||||
self.sub_adapters.append(
|
|
||||||
LoraAdapter[type(model)](
|
|
||||||
model,
|
|
||||||
sub_targets=lora_targets(model, model_targets),
|
|
||||||
scale=scale,
|
|
||||||
weights=lora_weights,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
|
||||||
|
|
||||||
@classmethod
|
# if no key contains "unet" or "text", assume all keys are for the unet
|
||||||
def from_safetensors(
|
if not "unet" in loras and not "text" in loras:
|
||||||
cls,
|
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
|
||||||
target: StableDiffusion_1,
|
|
||||||
checkpoint_path: Path | str,
|
|
||||||
scale: float = 1.0,
|
|
||||||
):
|
|
||||||
metadata = load_metadata_from_safetensors(checkpoint_path)
|
|
||||||
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
|
|
||||||
tensors = load_from_safetensors(checkpoint_path, device=target.device)
|
|
||||||
|
|
||||||
sub_targets: dict[str, list[LoraTarget]] = {}
|
self.load_unet(loras)
|
||||||
for model_name in MODELS:
|
self.load_text_encoder(loras)
|
||||||
if not (v := metadata.get(f"{model_name}_targets", "")):
|
|
||||||
continue
|
|
||||||
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
|
|
||||||
|
|
||||||
return cls(
|
self.scale = scale
|
||||||
target,
|
|
||||||
sub_targets,
|
|
||||||
scale=scale,
|
|
||||||
weights=tensors,
|
|
||||||
)
|
|
||||||
|
|
||||||
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
|
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
|
||||||
for adapter in self.sub_adapters:
|
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||||
adapter.inject()
|
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
|
||||||
return super().inject(parent)
|
|
||||||
|
|
||||||
def eject(self) -> None:
|
def load_unet(self, loras: dict[str, Lora], /) -> None:
|
||||||
for adapter in self.sub_adapters:
|
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
||||||
adapter.eject()
|
exclude: list[str] = []
|
||||||
super().eject()
|
exclude = [
|
||||||
|
self.unet_exclusions[exclusion]
|
||||||
|
for exclusion in self.unet_exclusions
|
||||||
|
if all([exclusion not in key for key in unet_loras.keys()])
|
||||||
|
]
|
||||||
|
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
|
||||||
|
|
||||||
|
def unload(self) -> None:
|
||||||
|
for lora_adapter in self.lora_adapters:
|
||||||
|
lora_adapter.eject()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loras(self) -> list[Lora]:
|
||||||
|
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lora_adapters(self) -> list[LoraAdapter]:
|
||||||
|
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unet_exclusions(self) -> dict[str, str]:
|
||||||
|
return {
|
||||||
|
"time": "TimestepEncoder",
|
||||||
|
"res": "ResidualBlock",
|
||||||
|
"downsample": "DownsampleBlock",
|
||||||
|
"upsample": "UpsampleBlock",
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self) -> float:
|
||||||
|
assert len(self.loras) > 0, "No loras found"
|
||||||
|
assert all([lora.scale == self.loras[0].scale for lora in self.loras])
|
||||||
|
return self.loras[0].scale
|
||||||
|
|
||||||
|
@scale.setter
|
||||||
|
def scale(self, value: float) -> None:
|
||||||
|
for lora in self.loras:
|
||||||
|
lora.scale = value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pad(input: str, /, padding_length: int = 2) -> str:
|
||||||
|
new_split: list[str] = []
|
||||||
|
for s in input.split("_"):
|
||||||
|
if s.isdigit():
|
||||||
|
new_split.append(s.zfill(padding_length))
|
||||||
|
else:
|
||||||
|
new_split.append(s)
|
||||||
|
return "_".join(new_split)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def sort_keys(key: str, /) -> tuple[str, int]:
|
||||||
|
# out0 happens sometimes as an alias for out ; this dict might not be exhaustive
|
||||||
|
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
|
||||||
|
|
||||||
|
for i, s in enumerate(key.split("_")):
|
||||||
|
if s in key_char_order:
|
||||||
|
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
|
||||||
|
return (prefix, key_char_order[s])
|
||||||
|
|
||||||
|
return (SDLoraManager.pad(key), 5)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def auto_attach(
|
||||||
|
loras: dict[str, Lora],
|
||||||
|
target: fl.Chain,
|
||||||
|
/,
|
||||||
|
exclude: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
failed_loras: dict[str, Lora] = {}
|
||||||
|
for key, lora in loras.items():
|
||||||
|
if attach := lora.auto_attach(target, exclude=exclude):
|
||||||
|
adapter, parent = attach
|
||||||
|
adapter.inject(parent)
|
||||||
|
else:
|
||||||
|
failed_loras[key] = lora
|
||||||
|
|
||||||
|
if failed_loras:
|
||||||
|
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
|
||||||
|
|
||||||
|
# TODO: add a stronger sanity check to make sure loras are attached correctly
|
||||||
|
|
|
@ -1,82 +0,0 @@
|
||||||
from torch import allclose, randn
|
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
|
||||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter, SingleLoraAdapter
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_lora_adapter() -> None:
|
|
||||||
chain = fl.Chain(
|
|
||||||
fl.Chain(
|
|
||||||
fl.Linear(in_features=1, out_features=1),
|
|
||||||
fl.Linear(in_features=1, out_features=1),
|
|
||||||
),
|
|
||||||
fl.Linear(in_features=1, out_features=2),
|
|
||||||
)
|
|
||||||
x = randn(1, 1)
|
|
||||||
y = chain(x)
|
|
||||||
|
|
||||||
lora_adapter = SingleLoraAdapter(chain.Chain.Linear_1).inject(chain.Chain)
|
|
||||||
|
|
||||||
assert isinstance(lora_adapter[1], Lora)
|
|
||||||
assert allclose(input=chain(x), other=y)
|
|
||||||
assert lora_adapter.parent == chain.Chain
|
|
||||||
|
|
||||||
lora_adapter.eject()
|
|
||||||
assert isinstance(chain.Chain[0], fl.Linear)
|
|
||||||
assert len(chain) == 2
|
|
||||||
|
|
||||||
lora_adapter.inject(chain.Chain)
|
|
||||||
assert isinstance(chain.Chain[0], SingleLoraAdapter)
|
|
||||||
|
|
||||||
|
|
||||||
def test_lora_adapter() -> None:
|
|
||||||
chain = fl.Chain(
|
|
||||||
fl.Chain(
|
|
||||||
fl.Linear(in_features=1, out_features=1),
|
|
||||||
fl.Linear(in_features=1, out_features=1),
|
|
||||||
),
|
|
||||||
fl.Linear(in_features=1, out_features=2),
|
|
||||||
)
|
|
||||||
|
|
||||||
# create and inject twice
|
|
||||||
|
|
||||||
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
|
||||||
assert len(list(chain.layers(Lora))) == 3
|
|
||||||
|
|
||||||
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
|
||||||
assert len(list(chain.layers(Lora))) == 6
|
|
||||||
|
|
||||||
# If we init a LoRA when another LoRA is already injected, the Linear
|
|
||||||
# layers of the first LoRA will be adapted too, which is typically not
|
|
||||||
# what we want.
|
|
||||||
# This issue can be avoided either by making the predicate for
|
|
||||||
# `walk` raise StopIteration when it encounters a LoRA (see the SD LoRA)
|
|
||||||
# or by creating all the LoRA Adapters first, before injecting them
|
|
||||||
# (see below).
|
|
||||||
assert len(list(chain.layers(Lora, recurse=True))) == 12
|
|
||||||
|
|
||||||
# ejection in forward order
|
|
||||||
|
|
||||||
a1.eject()
|
|
||||||
assert len(list(chain.layers(Lora))) == 3
|
|
||||||
a2.eject()
|
|
||||||
assert len(list(chain.layers(Lora))) == 0
|
|
||||||
|
|
||||||
# create twice then inject twice
|
|
||||||
|
|
||||||
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
|
|
||||||
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
|
|
||||||
a1.inject()
|
|
||||||
a2.inject()
|
|
||||||
assert len(list(chain.layers(Lora))) == 6
|
|
||||||
|
|
||||||
# If we inject after init we do not have the target selection problem,
|
|
||||||
# the LoRA layers are not adapted.
|
|
||||||
assert len(list(chain.layers(Lora, recurse=True))) == 6
|
|
||||||
|
|
||||||
# ejection in reverse order
|
|
||||||
|
|
||||||
a2.eject()
|
|
||||||
assert len(list(chain.layers(Lora))) == 3
|
|
||||||
a1.eject()
|
|
||||||
assert len(list(chain.layers(Lora))) == 0
|
|
|
@ -19,7 +19,7 @@ from refiners.foundationals.latent_diffusion import (
|
||||||
StableDiffusion_1,
|
StableDiffusion_1,
|
||||||
StableDiffusion_1_Inpainting,
|
StableDiffusion_1_Inpainting,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
||||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||||
|
@ -182,14 +182,38 @@ def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[
|
||||||
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
|
||||||
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
|
weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors"
|
||||||
|
|
||||||
|
if not weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
return name, condition_image, expected_image, weights_path
|
return name, condition_image, expected_image, weights_path
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
|
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||||
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
||||||
weights_path = test_weights_path / "loras" / "pcuenq_pokemon_lora.safetensors"
|
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
||||||
return expected_image, weights_path
|
|
||||||
|
if not weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
tensors = torch.load(weights_path) # type: ignore
|
||||||
|
return expected_image, tensors
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]:
|
||||||
|
expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
|
||||||
|
weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors"
|
||||||
|
|
||||||
|
if not weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
tensors = load_from_safetensors(weights_path)
|
||||||
|
return expected_image, tensors
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -1010,24 +1034,20 @@ def test_diffusion_controlnet_stack(
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_lora(
|
def test_diffusion_lora(
|
||||||
sd15_std: StableDiffusion_1,
|
sd15_std: StableDiffusion_1,
|
||||||
lora_data_pokemon: tuple[Image.Image, Path],
|
lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]],
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
):
|
) -> None:
|
||||||
sd15 = sd15_std
|
sd15 = sd15_std
|
||||||
n_steps = 30
|
n_steps = 30
|
||||||
|
|
||||||
expected_image, lora_weights_path = lora_data_pokemon
|
expected_image, lora_weights = lora_data_pokemon
|
||||||
|
|
||||||
if not lora_weights_path.is_file():
|
|
||||||
warn(f"could not find weights at {lora_weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
prompt = "a cute cat"
|
prompt = "a cute cat"
|
||||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
|
SDLoraManager(sd15).load(lora_weights, scale=1)
|
||||||
|
|
||||||
manual_seed(2)
|
manual_seed(2)
|
||||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
@ -1045,77 +1065,45 @@ def test_diffusion_lora(
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_lora_float16(
|
def test_diffusion_sdxl_lora(
|
||||||
sd15_std_float16: StableDiffusion_1,
|
sdxl_ddim: StableDiffusion_XL,
|
||||||
lora_data_pokemon: tuple[Image.Image, Path],
|
lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]],
|
||||||
test_device: torch.device,
|
) -> None:
|
||||||
):
|
sdxl = sdxl_ddim
|
||||||
sd15 = sd15_std_float16
|
expected_image, lora_weights = lora_data_dpo
|
||||||
n_steps = 30
|
|
||||||
|
|
||||||
expected_image, lora_weights_path = lora_data_pokemon
|
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
|
||||||
|
# except that we are using DDIM instead of sde-dpmsolver++
|
||||||
|
n_steps = 40
|
||||||
|
seed = 12341234123
|
||||||
|
guidance_scale = 7.5
|
||||||
|
lora_scale = 1.4
|
||||||
|
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
|
||||||
|
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
|
||||||
|
|
||||||
if not lora_weights_path.is_file():
|
SDLoraManager(sdxl).load(lora_weights, scale=lora_scale)
|
||||||
warn(f"could not find weights at {lora_weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
prompt = "a cute cat"
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
text=prompt, negative_text=negative_prompt
|
||||||
|
)
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
time_ids = sdxl.default_time_ids
|
||||||
|
sdxl.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
|
manual_seed(seed=seed)
|
||||||
|
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
|
||||||
|
|
||||||
manual_seed(2)
|
for step in sdxl.steps:
|
||||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
x = sdxl(
|
||||||
|
|
||||||
for step in sd15.steps:
|
|
||||||
x = sd15(
|
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
condition_scale=7.5,
|
pooled_text_embedding=pooled_text_embedding,
|
||||||
|
time_ids=time_ids,
|
||||||
|
condition_scale=guidance_scale,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
|
||||||
|
|
||||||
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
|
predicted_image = sdxl.lda.decode_latents(x)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
|
||||||
def test_diffusion_lora_twice(
|
|
||||||
sd15_std: StableDiffusion_1,
|
|
||||||
lora_data_pokemon: tuple[Image.Image, Path],
|
|
||||||
test_device: torch.device,
|
|
||||||
):
|
|
||||||
sd15 = sd15_std
|
|
||||||
n_steps = 30
|
|
||||||
|
|
||||||
expected_image, lora_weights_path = lora_data_pokemon
|
|
||||||
|
|
||||||
if not lora_weights_path.is_file():
|
|
||||||
warn(f"could not find weights at {lora_weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
prompt = "a cute cat"
|
|
||||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
|
||||||
|
|
||||||
# The same LoRA is used twice which is not a common use case: this is purely for testing purpose
|
|
||||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject()
|
|
||||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject()
|
|
||||||
|
|
||||||
manual_seed(2)
|
|
||||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
|
||||||
|
|
||||||
for step in sd15.steps:
|
|
||||||
x = sd15(
|
|
||||||
x,
|
|
||||||
step=step,
|
|
||||||
clip_text_embedding=clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
|
||||||
)
|
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
|
||||||
|
|
||||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,7 @@ Special cases:
|
||||||
- `expected_restart.png`
|
- `expected_restart.png`
|
||||||
- `expected_freeu.png`
|
- `expected_freeu.png`
|
||||||
- `expected_dropy_slime_9752.png`
|
- `expected_dropy_slime_9752.png`
|
||||||
|
- `expected_sdxl_dpo_lora.png`
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_sdxl_dpo_lora.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_sdxl_dpo_lora.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.6 MiB |
Loading…
Reference in a new issue