diff --git a/README.md b/README.md
index 7dfb632..5d0c1e8 100644
--- a/README.md
+++ b/README.md
@@ -158,16 +158,6 @@ You should get:
 
 ![dropy slime output](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_slime_9752.png)
 
-### Training
-
-Refiners has a built-in training utils library and provides scripts that can be used as a starting point.
-
-E.g. to train a LoRA on top of Stable Diffusion, copy and edit `configs/finetune-lora.toml` to suit your needs and launch the training as follows:
-
-```bash
-python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml
-```
-
 ## Adapter Zoo
 
 For now, given [finegrain](https://finegrain.ai)'s mission, we are focusing on image editing tasks. We support:
diff --git a/configs/finetune-lora.toml b/configs/finetune-lora.toml
deleted file mode 100644
index 4d4409c..0000000
--- a/configs/finetune-lora.toml
+++ /dev/null
@@ -1,68 +0,0 @@
-[wandb]
-mode = "offline" # "online", "offline", "disabled"
-entity = "acme"
-project = "test-lora-training"
-
-[models]
-unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
-text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
-lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
-
-[latent_diffusion]
-unconditional_sampling_probability = 0.05
-offset_noise = 0.1
-
-[lora]
-rank = 16
-trigger_phrase = "a spsh photo,"
-use_only_trigger_probability = 1.0
-unet_targets = ["CrossAttentionBlock2d"]
-text_encoder_targets = ["TransformerLayer"]
-lda_targets = []
-
-[training]
-duration = "1000:epoch"
-seed = 0
-gpu_index = 0
-batch_size = 4
-gradient_accumulation = "4:step"
-clip_grad_norm = 1.0
-# clip_grad_value = 1.0
-evaluation_interval = "5:epoch"
-evaluation_seed = 1
-
-
-[optimizer]
-optimizer = "Prodigy" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
-learning_rate = 1
-betas = [0.9, 0.999]
-eps = 1e-8
-weight_decay = 1e-2
-
-[scheduler]
-scheduler_type = "ConstantLR"
-update_interval = "1:step"
-warmup = "500:step"
-
-
-[dropout]
-dropout_probability = 0.2
-use_gyro_dropout = false
-
-[dataset]
-hf_repo = "acme/images"
-revision = "main"
-
-[checkpointing]
-# save_folder = "/path/to/ckpts"
-save_interval = "1:step"
-
-[test_diffusion]
-num_inference_steps = 30
-use_short_prompts = false
-prompts = [
-    "a cute cat",
-    "a cute dog",
-    "a cute bird",
-    "a cute horse",
-]
diff --git a/scripts/conversion/convert_diffusers_lora.py b/scripts/conversion/convert_diffusers_lora.py
deleted file mode 100644
index 1c37d8d..0000000
--- a/scripts/conversion/convert_diffusers_lora.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import argparse
-from pathlib import Path
-from typing import cast
-
-import torch
-from diffusers import DiffusionPipeline  # type: ignore
-from torch import Tensor
-from torch.nn import Parameter as TorchParameter
-from torch.nn.init import zeros_
-
-import refiners.fluxion.layers as fl
-from refiners.fluxion.adapters.lora import Lora, LoraAdapter
-from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import no_grad, save_to_safetensors
-from refiners.foundationals.latent_diffusion import SD1UNet
-from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
-
-
-def get_weight(linear: fl.Linear) -> torch.Tensor:
-    assert linear.bias is None
-    return linear.state_dict()["weight"]
-
-
-def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
-    weights: list[torch.Tensor] = []
-    for lora in module.layers(layer_type=Lora):
-        linears = list(lora.layers(layer_type=fl.Linear))
-        assert len(linears) == 2
-        weights.extend((get_weight(linear=linears[1]), get_weight(linear=linears[0])))  # aka (up_weight, down_weight)
-    return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
-
-
-class Args(argparse.Namespace):
-    source_path: str
-    base_model: str
-    output_path: str
-    verbose: bool
-
-
-@no_grad()
-def process(args: Args) -> None:
-    diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu"))  # type: ignore
-    # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
-    diffusers_sd = DiffusionPipeline.from_pretrained(  # type: ignore
-        pretrained_model_name_or_path=args.base_model,
-        low_cpu_mem_usage=False,
-    )
-    diffusers_model = cast(fl.Module, diffusers_sd.unet)  # type: ignore
-
-    refiners_model = SD1UNet(in_channels=4)
-    target = LoraTarget.CrossAttention
-    metadata = {"unet_targets": "CrossAttentionBlock2d"}
-    rank = diffusers_state_dict[
-        "mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
-    ].shape[0]
-
-    x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor(data=[0])
-    clip_text_embeddings = torch.randn(1, 77, 768)
-
-    refiners_model.set_timestep(timestep=timestep)
-    refiners_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
-    refiners_args = (x,)
-
-    diffusers_args = (x, timestep, clip_text_embeddings)
-
-    converter = ModelConverter(
-        source_model=refiners_model, target_model=diffusers_model, skip_output_check=True, verbose=args.verbose
-    )
-    if not converter.run(
-        source_args=refiners_args,
-        target_args=diffusers_args,
-    ):
-        raise RuntimeError("Model conversion failed")
-
-    diffusers_to_refiners = converter.get_mapping()
-
-    LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
-
-    for layer in refiners_model.layers(layer_type=Lora):
-        zeros_(tensor=layer.Linear_1.weight)
-
-    targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
-    for target_k in targets:
-        k_p, k_s = target_k.split(".processor.")
-        r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
-        assert len(r) == 1
-        orig_k = r[0]
-        orig_path = orig_k.split(sep=".")
-        p = refiners_model
-        for seg in orig_path[:-1]:
-            p = p[seg]
-        assert isinstance(p, fl.Chain)
-        last_seg = (
-            "SingleLoraAdapter"
-            if orig_path[-1] == "Linear"
-            else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
-        )
-        p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
-        p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
-        p[last_seg].Lora.load_weights(p_down, p_up)
-
-    state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
-    assert len(state_dict) == 320
-    save_to_safetensors(path=args.output_path, tensors=state_dict, metadata=metadata)
-
-
-def main() -> None:
-    parser = argparse.ArgumentParser(description="Convert LoRAs saved using the diffusers library to refiners format.")
-    parser.add_argument(
-        "--from",
-        type=str,
-        dest="source_path",
-        required=True,
-        help="Source file path (.bin|safetensors) containing the LoRAs.",
-    )
-    parser.add_argument(
-        "--base-model",
-        type=str,
-        required=False,
-        default="runwayml/stable-diffusion-v1-5",
-        help="Base model, used for the UNet structure. 
Default: runwayml/stable-diffusion-v1-5", - ) - parser.add_argument( - "--to", - type=str, - dest="output_path", - required=False, - default=None, - help=( - "Output file path (.safetensors) for converted LoRAs. If not provided, the output path will be the same as" - " the source path." - ), - ) - parser.add_argument( - "--verbose", - action="store_true", - dest="verbose", - default=False, - help="Use this flag to print verbose output during conversion.", - ) - args = parser.parse_args(namespace=Args()) - if args.output_path is None: - args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors" - process(args=args) - - -if __name__ == "__main__": - main() diff --git a/scripts/conversion/convert_refiners_lora_to_sdwebui.py b/scripts/conversion/convert_refiners_lora_to_sdwebui.py deleted file mode 100644 index e8105ea..0000000 --- a/scripts/conversion/convert_refiners_lora_to_sdwebui.py +++ /dev/null @@ -1,124 +0,0 @@ -import argparse -from functools import partial - -from convert_diffusers_unet import Args as UnetConversionArgs, setup_converter as convert_unet -from convert_transformers_clip_text_model import ( - Args as TextEncoderConversionArgs, - setup_converter as convert_text_encoder, -) -from torch import Tensor - -import refiners.fluxion.layers as fl -from refiners.fluxion.utils import ( - load_from_safetensors, - load_metadata_from_safetensors, - save_to_safetensors, -) -from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL -from refiners.foundationals.latent_diffusion import SD1UNet -from refiners.foundationals.latent_diffusion.lora import LoraTarget - - -def get_unet_mapping(source_path: str) -> dict[str, str]: - args = UnetConversionArgs(source_path=source_path, verbose=False) - return convert_unet(args=args).get_mapping() - - -def get_text_encoder_mapping(source_path: str) -> dict[str, str]: - args = TextEncoderConversionArgs(source_path=source_path, subfolder="text_encoder", verbose=False) - return convert_text_encoder( - args=args, - ).get_mapping() - - -def main() -> None: - parser = argparse.ArgumentParser(description="Converts a refiner's LoRA weights to SD-WebUI's LoRA weights") - parser.add_argument( - "-i", - "--input-file", - type=str, - required=True, - help="Path to the input file with refiner's LoRA weights (safetensors format)", - ) - parser.add_argument( - "-o", - "--output-file", - type=str, - default="sdwebui_loras.safetensors", - help="Path to the output file with sd-webui's LoRA weights (safetensors format)", - ) - parser.add_argument( - "--sd15", - type=str, - required=False, - default="runwayml/stable-diffusion-v1-5", - help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)", - ) - args = parser.parse_args() - metadata = load_metadata_from_safetensors(path=args.input_file) - assert metadata is not None, f"Could not load metadata from {args.input_file}" - tensors = load_from_safetensors(path=args.input_file) - - state_dict: dict[str, Tensor] = {} - - for meta_key, meta_value in metadata.items(): - match meta_key: - case "unet_targets": - model = SD1UNet(in_channels=4) - create_mapping = partial(get_unet_mapping, source_path=args.sd15) - key_prefix = "unet." - lora_prefix = "lora_unet_" - case "text_encoder_targets": - model = CLIPTextEncoderL() - create_mapping = partial(get_text_encoder_mapping, source_path=args.sd15) - key_prefix = "text_encoder." 
- lora_prefix = "lora_te_" - case "lda_targets": - raise ValueError("SD-WebUI does not support LoRA for the auto-encoder") - case _: - raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}") - - submodule_to_key: dict[fl.Module, str] = {} - for name, submodule in model.named_modules(): - submodule_to_key[submodule] = name - - # SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.: - # - # lora_unet_down_blocks_0_attentions_0_proj_in.alpha - # lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight - # lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight - # ... - # - # Internally SD-WebUI has some logic[1] to convert such keys into the CompVis format. See - # `convert_diffusers_name_to_compvis` for more details. - # - # [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225 - - refiners_to_diffusers = create_mapping() - assert refiners_to_diffusers is not None - - # Compute the corresponding diffusers' keys where LoRA layers must be applied - lora_injection_points: list[str] = [ - refiners_to_diffusers[submodule_to_key[linear]] - for target in [LoraTarget(t) for t in meta_value.split(sep=",")] - for layer in model.layers(layer_type=target.get_class()) - for linear in layer.layers(layer_type=fl.Linear) - ] - - lora_weights = [w for w in [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]] - assert len(lora_injection_points) == len(lora_weights) // 2 - - # Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores) - for i, diffusers_key in enumerate(iterable=lora_injection_points): - lora_key = lora_prefix + diffusers_key.replace(".", "_") - # Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). 
Refiners uses a scale = 1.0
-            # by default (see `lora_calc_updown` in SD-WebUI for more details)
-            state_dict[lora_key + ".lora_up.weight"] = lora_weights[2 * i]
-            state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
-
-    assert state_dict
-    save_to_safetensors(path=args.output_file, tensors=state_dict)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index 39973b3..5a6d928 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -229,10 +229,15 @@ def download_vae_ft_mse():
     )
 
 
-def download_lora():
-    dest_folder = os.path.join(test_weights_dir, "pcuenq", "pokemon-lora")
+def download_loras():
+    dest_folder = os.path.join(test_weights_dir, "loras", "pokemon-lora")
     download_file("https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin", dest_folder)
 
+    dest_folder = os.path.join(test_weights_dir, "loras", "dpo-lora")
+    download_file(
+        "https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
+    )
+
 
 def download_preprocessors():
     dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
@@ -454,17 +459,6 @@ def convert_vae_fp16_fix():
     )
 
 
-def convert_lora():
-    os.makedirs("tests/weights/loras", exist_ok=True)
-    run_conversion_script(
-        "convert_diffusers_lora.py",
-        "tests/weights/pcuenq/pokemon-lora/pytorch_lora_weights.bin",
-        "tests/weights/loras/pcuenq_pokemon_lora.safetensors",
-        additional_args=["--base-model", "tests/weights/runwayml/stable-diffusion-v1-5"],
-        expected_hash="a9d7e08e",
-    )
-
-
 def convert_preprocessors():
     subprocess.run(
         [
@@ -632,7 +626,7 @@ def download_all():
     download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
     download_vae_ft_mse()
     download_vae_fp16_fix()
-    download_lora()
+    download_loras()
     download_preprocessors()
     download_controlnet()
     download_unclip()
@@ -647,7 +641,7 @@ def convert_all():
     convert_sdxl()
     convert_vae_ft_mse()
    convert_vae_fp16_fix()
-    convert_lora()
+    # Note: `convert_lora` is gone; LoRA conversion now happens at runtime via `SDLoraManager`
     convert_preprocessors()
     convert_controlnet()
     convert_unclip()
diff --git a/scripts/training/finetune-ldm-lora.py b/scripts/training/finetune-ldm-lora.py
deleted file mode 100644
index 32bcf16..0000000
--- a/scripts/training/finetune-ldm-lora.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import random
-from typing import Any
-
-from loguru import logger
-from pydantic import BaseModel
-from torch import Tensor
-from torch.utils.data import Dataset
-
-import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
-from refiners.training_utils.callback import Callback
-from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
-from refiners.training_utils.latent_diffusion import (
-    FinetuneLatentDiffusionConfig,
-    LatentDiffusionConfig,
-    LatentDiffusionTrainer,
-    TextEmbeddingLatentsBatch,
-    TextEmbeddingLatentsDataset,
-)
-
-
-class LoraConfig(BaseModel):
-    rank: int = 32
-    trigger_phrase: str = ""
-    use_only_trigger_probability: float = 0.0
-    unet_targets: list[LoraTarget]
-    text_encoder_targets: list[LoraTarget]
-    lda_targets: list[LoraTarget]
-
-
-class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
-    def __init__(
-        self,
-        trainer: "LoraLatentDiffusionTrainer",
-    ) -> None:
-        super().__init__(trainer=trainer)
-        self.trigger_phrase = trainer.config.lora.trigger_phrase
- self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability - logger.info(f"Trigger phrase: {self.trigger_phrase}") - - def process_caption(self, caption: str) -> str: - caption = super().process_caption(caption=caption) - if self.trigger_phrase: - caption = ( - f"{self.trigger_phrase} {caption}" - if random.random() < self.use_only_trigger_probability - else self.trigger_phrase - ) - return caption - - -class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig): - dataset: HuggingfaceDatasetConfig - latent_diffusion: LatentDiffusionConfig - lora: LoraConfig - - def model_post_init(self, __context: Any) -> None: - """Pydantic v2 does post init differently, so we need to override this method too.""" - logger.info("Freezing models to train only the loras.") - self.models["unet"].train = False - self.models["text_encoder"].train = False - self.models["lda"].train = False - - -class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]): - def __init__( - self, - config: LoraLatentDiffusionConfig, - callbacks: "list[Callback[Any]] | None" = None, - ) -> None: - super().__init__(config=config, callbacks=callbacks) - self.callbacks.extend((LoadLoras(), SaveLoras())) - - def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: - return TriggerPhraseDataset(trainer=self) - - -class LoadLoras(Callback[LoraLatentDiffusionTrainer]): - def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None: - lora_config = trainer.config.lora - - for model_name in MODELS: - model = getattr(trainer, model_name) - model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets") - adapter = LoraAdapter[type(model)]( - model, - sub_targets=lora_targets(model, model_targets), - rank=lora_config.rank, - ) - for sub_adapter, _ in adapter.sub_adapters: - for linear in sub_adapter.Lora.layers(fl.Linear): - linear.requires_grad_(requires_grad=True) - adapter.inject() - - -class SaveLoras(Callback[LoraLatentDiffusionTrainer]): - def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None: - tensors: dict[str, Tensor] = {} - metadata: dict[str, str] = {} - - for model_name in MODELS: - model = getattr(trainer, model_name) - adapter = model.parent - tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)} - metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)} - - save_to_safetensors( - path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", - tensors=tensors, - metadata=metadata, - ) - - -if __name__ == "__main__": - import sys - - config_path = sys.argv[1] - config = LoraLatentDiffusionConfig.load_from_toml( - toml_path=config_path, - ) - trainer = LoraLatentDiffusionTrainer(config=config) - trainer.train() diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index b0f2e89..6b8e992 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -1,4 +1,5 @@ -from typing import Any, Generic, Iterable, TypeVar +from abc import ABC, abstractmethod +from typing import Any from torch import Tensor, device as Device, dtype as DType from torch.nn import Parameter as TorchParameter @@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_ import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter - -T = TypeVar("T", bound=fl.Chain) -TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673) +from refiners.fluxion.layers.chain import 
Chain -class Lora(fl.Chain): +class Lora(fl.Chain, ABC): + def __init__( + self, + rank: int = 16, + scale: float = 1.0, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.rank = rank + self._scale = scale + + super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale)) + + normal_(tensor=self.down.weight, std=1 / self.rank) + zeros_(tensor=self.up.weight) + + @abstractmethod + def lora_layers( + self, device: Device | str | None = None, dtype: DType | None = None + ) -> tuple[fl.WeightedModule, fl.WeightedModule]: + ... + + @property + def down(self) -> fl.WeightedModule: + down_layer = self[0] + assert isinstance(down_layer, fl.WeightedModule) + return down_layer + + @property + def up(self) -> fl.WeightedModule: + up_layer = self[1] + assert isinstance(up_layer, fl.WeightedModule) + return up_layer + + @property + def scale(self) -> float: + return self._scale + + @scale.setter + def scale(self, value: float) -> None: + self._scale = value + self.ensure_find(fl.Multiply).scale = value + + @classmethod + def from_weights( + cls, + down: Tensor, + up: Tensor, + ) -> "Lora": + match (up.ndim, down.ndim): + case (2, 2): + return LinearLora.from_weights(up=up, down=down) + case (4, 4): + return Conv2dLora.from_weights(up=up, down=down) + case _: + raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}") + + @classmethod + def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]: + """ + Create a dictionary of LoRA layers from a state dict. + + Expects the state dict to be a succession of down and up weights. + """ + state_dict = {k: v for k, v in state_dict.items() if ".weight" in k} + loras: dict[str, Lora] = {} + for down_key, down_tensor, up_tensor in zip( + list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2] + ): + key = ".".join(down_key.split(".")[:-2]) + loras[key] = cls.from_weights(down=down_tensor, up=up_tensor) + return loras + + @abstractmethod + def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any: + ... 
+ + def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: + assert down_weight.shape == self.down.weight.shape + assert up_weight.shape == self.up.weight.shape + self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype)) + self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype)) + + +class LinearLora(Lora): def __init__( self, in_features: int, out_features: int, rank: int = 16, + scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None, ) -> None: self.in_features = in_features self.out_features = out_features - self.rank = rank - self.scale: float = 1.0 - super().__init__( - fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype), - fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype), - fl.Lambda(func=self.scale_outputs), + super().__init__(rank=rank, scale=scale, device=device, dtype=dtype) + + @classmethod + def from_weights( + cls, + down: Tensor, + up: Tensor, + ) -> "LinearLora": + assert up.ndim == 2 and down.ndim == 2 + assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}" + lora = cls( + in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype + ) + lora.load_weights(down_weight=down, up_weight=up) + return lora + + def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None": + for layer, parent in target.walk(fl.Linear): + if isinstance(parent, Lora) or isinstance(parent, LoraAdapter): + continue + + if exclude is not None and any( + [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] + ): + continue + + if layer.in_features == self.in_features and layer.out_features == self.out_features: + return LoraAdapter(target=layer, lora=self), parent + + def lora_layers( + self, device: Device | str | None = None, dtype: DType | None = None + ) -> tuple[fl.Linear, fl.Linear]: + return ( + fl.Linear( + in_features=self.in_features, + out_features=self.rank, + bias=False, + device=device, + dtype=dtype, + ), + fl.Linear( + in_features=self.rank, + out_features=self.out_features, + bias=False, + device=device, + dtype=dtype, + ), ) - normal_(tensor=self.Linear_1.weight, std=1 / self.rank) - zeros_(tensor=self.Linear_2.weight) - def scale_outputs(self, x: Tensor) -> Tensor: - return x * self.scale - - def set_scale(self, scale: float) -> None: - self.scale = scale - - def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: - self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype)) - self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype)) - - @property - def up_weight(self) -> Tensor: - return self.Linear_2.weight.data - - @property - def down_weight(self) -> Tensor: - return self.Linear_1.weight.data - - -class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]): +class Conv2dLora(Lora): def __init__( self, - target: fl.Linear, + in_channels: int, + out_channels: int, rank: int = 16, scale: float = 1.0, + kernel_size: tuple[int, int] = (1, 3), + stride: tuple[int, int] = (1, 1), + padding: tuple[int, int] = (0, 1), + device: Device | str | None = None, + dtype: DType | None = None, ) -> None: - self.in_features = target.in_features - self.out_features = target.out_features - self.rank = rank - self.scale = scale + self.in_channels = in_channels 
+ self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + super().__init__(rank=rank, scale=scale, device=device, dtype=dtype) + + @classmethod + def from_weights( + cls, + down: Tensor, + up: Tensor, + ) -> "Conv2dLora": + assert up.ndim == 4 and down.ndim == 4 + assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}" + down_kernel_size, up_kernel_size = down.shape[2], up.shape[2] + down_padding = 1 if down_kernel_size == 3 else 0 + up_padding = 1 if up_kernel_size == 3 else 0 + lora = cls( + in_channels=down.shape[1], + out_channels=up.shape[0], + rank=down.shape[0], + kernel_size=(down_kernel_size, up_kernel_size), + padding=(down_padding, up_padding), + device=up.device, + dtype=up.dtype, + ) + lora.load_weights(down_weight=down, up_weight=up) + return lora + + def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None": + for layer, parent in target.walk(fl.Conv2d): + if isinstance(parent, Lora) or isinstance(parent, LoraAdapter): + continue + + if exclude is not None and any( + [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] + ): + continue + + if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels: + if layer.stride != (self.stride[0], self.stride[0]): + self.down.stride = layer.stride + + return LoraAdapter( + target=layer, + lora=self, + ), parent + + def lora_layers( + self, device: Device | str | None = None, dtype: DType | None = None + ) -> tuple[fl.Conv2d, fl.Conv2d]: + return ( + fl.Conv2d( + in_channels=self.in_channels, + out_channels=self.rank, + kernel_size=self.kernel_size[0], + stride=self.stride[0], + padding=self.padding[0], + use_bias=False, + device=device, + dtype=dtype, + ), + fl.Conv2d( + in_channels=self.rank, + out_channels=self.out_channels, + kernel_size=self.kernel_size[1], + stride=self.stride[1], + padding=self.padding[1], + use_bias=False, + device=device, + dtype=dtype, + ), + ) + + +class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]): + def __init__(self, target: fl.WeightedModule, lora: Lora) -> None: with self.setup_adapter(target): - super().__init__( - target, - Lora( - in_features=target.in_features, - out_features=target.out_features, - rank=rank, - device=target.device, - dtype=target.dtype, - ), - ) - self.Lora.set_scale(scale=scale) - - -class LoraAdapter(Generic[T], fl.Chain, Adapter[T]): - def __init__( - self, - target: T, - sub_targets: Iterable[tuple[fl.Linear, fl.Chain]], - rank: int | None = None, - scale: float = 1.0, - weights: list[Tensor] | None = None, - ) -> None: - with self.setup_adapter(target): - super().__init__(target) - - if weights is not None: - assert len(weights) % 2 == 0 - weights_rank = weights[0].shape[1] - if rank is None: - rank = weights_rank - else: - assert rank == weights_rank - - assert rank is not None, "either pass a rank or weights" - - self.sub_targets = sub_targets - self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = [] - - for linear, parent in self.sub_targets: - self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent)) - - if weights is not None: - assert len(self.sub_adapters) == (len(weights) // 2) - for i, (adapter, _) in enumerate(self.sub_adapters): - lora = adapter.Lora - assert ( - lora.rank == weights[i * 2].shape[1] - ), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}" - 
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1]) - - def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter: - for adapter, adapter_parent in self.sub_adapters: - adapter.inject(adapter_parent) - return super().inject(parent) - - def eject(self) -> None: - for adapter, _ in self.sub_adapters: - adapter.eject() - super().eject() + super().__init__(target, lora) @property - def weights(self) -> list[Tensor]: - return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]] + def lora(self) -> Lora: + return self.ensure_find(Lora) + + @property + def scale(self) -> float: + return self.lora.scale + + @scale.setter + def scale(self, value: float) -> None: + self.lora.scale = value diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index d8820ab..d25115d 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -1,146 +1,140 @@ -from enum import Enum -from pathlib import Path -from typing import Callable, Iterator +from warnings import warn from torch import Tensor import refiners.fluxion.layers as fl -from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.adapters.lora import Lora, LoraAdapter -from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors -from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer -from refiners.foundationals.latent_diffusion import ( - CLIPTextEncoderL, - LatentDiffusionAutoencoder, - SD1UNet, - StableDiffusion_1, -) -from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d -from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet - -MODELS = ["unet", "text_encoder", "lda"] +from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel -class LoraTarget(str, Enum): - Self = "self" - Attention = "Attention" - SelfAttention = "SelfAttention" - CrossAttention = "CrossAttentionBlock2d" - FeedForward = "FeedForward" - TransformerLayer = "TransformerLayer" - - def get_class(self) -> type[fl.Chain]: - match self: - case LoraTarget.Self: - return fl.Chain - case LoraTarget.Attention: - return fl.Attention - case LoraTarget.SelfAttention: - return fl.SelfAttention - case LoraTarget.CrossAttention: - return CrossAttentionBlock2d - case LoraTarget.FeedForward: - return FeedForward - case LoraTarget.TransformerLayer: - return TransformerLayer - - -def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]: - def f(m: fl.Module, _: fl.Chain) -> bool: - if isinstance(m, Lora): # do not adapt other LoRAs - raise StopIteration - if isinstance(m, Controlnet): # do not adapt Controlnet linears - raise StopIteration - return isinstance(m, k) - - return f - - -def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]: - for m, p in module.walk(_predicate(fl.Linear)): - assert isinstance(m, fl.Linear) - yield (m, p) - - -def lora_targets( - module: fl.Chain, - target: LoraTarget | list[LoraTarget], -) -> Iterator[tuple[fl.Linear, fl.Chain]]: - if isinstance(target, list): - for t in target: - yield from lora_targets(module, t) - return - - if target == LoraTarget.Self: - yield from _iter_linears(module) - return - - for layer, _ in module.walk(_predicate(target.get_class())): - assert isinstance(layer, fl.Chain) - yield from _iter_linears(layer) - - -class 
SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]): - metadata: dict[str, str] | None - tensors: dict[str, Tensor] - +class SDLoraManager: def __init__( self, - target: StableDiffusion_1, - sub_targets: dict[str, list[LoraTarget]], + target: LatentDiffusionModel, + ) -> None: + self.target = target + + @property + def unet(self) -> fl.Chain: + unet = self.target.unet + assert isinstance(unet, fl.Chain) + return unet + + @property + def clip_text_encoder(self) -> fl.Chain: + clip_text_encoder = self.target.clip_text_encoder + assert isinstance(clip_text_encoder, fl.Chain) + return clip_text_encoder + + def load( + self, + tensors: dict[str, Tensor], + /, scale: float = 1.0, - weights: dict[str, Tensor] | None = None, - ): - with self.setup_adapter(target): - super().__init__(target) + ) -> None: + """Load the LoRA weights from a dictionary of tensors. - self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = [] - - for model_name in MODELS: - if not (model_targets := sub_targets.get(model_name, [])): - continue - model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name) - - lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None - self.sub_adapters.append( - LoraAdapter[type(model)]( - model, - sub_targets=lora_targets(model, model_targets), - scale=scale, - weights=lora_weights, - ) - ) - - @classmethod - def from_safetensors( - cls, - target: StableDiffusion_1, - checkpoint_path: Path | str, - scale: float = 1.0, - ): - metadata = load_metadata_from_safetensors(checkpoint_path) - assert metadata is not None, "Invalid safetensors checkpoint: missing metadata" - tensors = load_from_safetensors(checkpoint_path, device=target.device) - - sub_targets: dict[str, list[LoraTarget]] = {} - for model_name in MODELS: - if not (v := metadata.get(f"{model_name}_targets", "")): - continue - sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")] - - return cls( - target, - sub_targets, - scale=scale, - weights=tensors, + Expects the keys to be in the commonly found formats on CivitAI's hub. 
+ """ + assert len(self.lora_adapters) == 0, "Loras already loaded" + loras = Lora.from_dict( + {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()} ) + loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)} - def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter": - for adapter in self.sub_adapters: - adapter.inject() - return super().inject(parent) + # if no key contains "unet" or "text", assume all keys are for the unet + if not "unet" in loras and not "text" in loras: + loras = {f"unet_{key}": loras[key] for key in loras.keys()} - def eject(self) -> None: - for adapter in self.sub_adapters: - adapter.eject() - super().eject() + self.load_unet(loras) + self.load_text_encoder(loras) + + self.scale = scale + + def load_text_encoder(self, loras: dict[str, Lora], /) -> None: + text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key} + SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder) + + def load_unet(self, loras: dict[str, Lora], /) -> None: + unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key} + exclude: list[str] = [] + exclude = [ + self.unet_exclusions[exclusion] + for exclusion in self.unet_exclusions + if all([exclusion not in key for key in unet_loras.keys()]) + ] + SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude) + + def unload(self) -> None: + for lora_adapter in self.lora_adapters: + lora_adapter.eject() + + @property + def loras(self) -> list[Lora]: + return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora)) + + @property + def lora_adapters(self) -> list[LoraAdapter]: + return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter)) + + @property + def unet_exclusions(self) -> dict[str, str]: + return { + "time": "TimestepEncoder", + "res": "ResidualBlock", + "downsample": "DownsampleBlock", + "upsample": "UpsampleBlock", + } + + @property + def scale(self) -> float: + assert len(self.loras) > 0, "No loras found" + assert all([lora.scale == self.loras[0].scale for lora in self.loras]) + return self.loras[0].scale + + @scale.setter + def scale(self, value: float) -> None: + for lora in self.loras: + lora.scale = value + + @staticmethod + def pad(input: str, /, padding_length: int = 2) -> str: + new_split: list[str] = [] + for s in input.split("_"): + if s.isdigit(): + new_split.append(s.zfill(padding_length)) + else: + new_split.append(s) + return "_".join(new_split) + + @staticmethod + def sort_keys(key: str, /) -> tuple[str, int]: + # out0 happens sometimes as an alias for out ; this dict might not be exhaustive + key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4} + + for i, s in enumerate(key.split("_")): + if s in key_char_order: + prefix = SDLoraManager.pad("_".join(key.split("_")[:i])) + return (prefix, key_char_order[s]) + + return (SDLoraManager.pad(key), 5) + + @staticmethod + def auto_attach( + loras: dict[str, Lora], + target: fl.Chain, + /, + exclude: list[str] | None = None, + ) -> None: + failed_loras: dict[str, Lora] = {} + for key, lora in loras.items(): + if attach := lora.auto_attach(target, exclude=exclude): + adapter, parent = attach + adapter.inject(parent) + else: + failed_loras[key] = lora + + if failed_loras: + warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}") + + # TODO: add a stronger sanity check to make sure loras are attached correctly diff --git 
a/tests/adapters/test_lora.py b/tests/adapters/test_lora.py deleted file mode 100644 index bdedfeb..0000000 --- a/tests/adapters/test_lora.py +++ /dev/null @@ -1,82 +0,0 @@ -from torch import allclose, randn - -import refiners.fluxion.layers as fl -from refiners.fluxion.adapters.lora import Lora, LoraAdapter, SingleLoraAdapter - - -def test_single_lora_adapter() -> None: - chain = fl.Chain( - fl.Chain( - fl.Linear(in_features=1, out_features=1), - fl.Linear(in_features=1, out_features=1), - ), - fl.Linear(in_features=1, out_features=2), - ) - x = randn(1, 1) - y = chain(x) - - lora_adapter = SingleLoraAdapter(chain.Chain.Linear_1).inject(chain.Chain) - - assert isinstance(lora_adapter[1], Lora) - assert allclose(input=chain(x), other=y) - assert lora_adapter.parent == chain.Chain - - lora_adapter.eject() - assert isinstance(chain.Chain[0], fl.Linear) - assert len(chain) == 2 - - lora_adapter.inject(chain.Chain) - assert isinstance(chain.Chain[0], SingleLoraAdapter) - - -def test_lora_adapter() -> None: - chain = fl.Chain( - fl.Chain( - fl.Linear(in_features=1, out_features=1), - fl.Linear(in_features=1, out_features=1), - ), - fl.Linear(in_features=1, out_features=2), - ) - - # create and inject twice - - a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject() - assert len(list(chain.layers(Lora))) == 3 - - a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject() - assert len(list(chain.layers(Lora))) == 6 - - # If we init a LoRA when another LoRA is already injected, the Linear - # layers of the first LoRA will be adapted too, which is typically not - # what we want. - # This issue can be avoided either by making the predicate for - # `walk` raise StopIteration when it encounters a LoRA (see the SD LoRA) - # or by creating all the LoRA Adapters first, before injecting them - # (see below). - assert len(list(chain.layers(Lora, recurse=True))) == 12 - - # ejection in forward order - - a1.eject() - assert len(list(chain.layers(Lora))) == 3 - a2.eject() - assert len(list(chain.layers(Lora))) == 0 - - # create twice then inject twice - - a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0) - a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0) - a1.inject() - a2.inject() - assert len(list(chain.layers(Lora))) == 6 - - # If we inject after init we do not have the target selection problem, - # the LoRA layers are not adapted. 
- assert len(list(chain.layers(Lora, recurse=True))) == 6 - - # ejection in reverse order - - a2.eject() - assert len(list(chain.layers(Lora))) == 3 - a1.eject() - assert len(list(chain.layers(Lora))) == 0 diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 04bb180..967fb0a 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -19,7 +19,7 @@ from refiners.foundationals.latent_diffusion import ( StableDiffusion_1, StableDiffusion_1_Inpainting, ) -from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter +from refiners.foundationals.latent_diffusion.lora import SDLoraManager from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.restart import Restart @@ -182,14 +182,38 @@ def t2i_adapter_xl_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[ condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB") expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB") weights_path = test_weights_path / "T2I-Adapter" / "t2i-adapter-canny-sdxl-1.0.safetensors" + + if not weights_path.is_file(): + warn(f"could not find weights at {weights_path}, skipping") + pytest.skip(allow_module_level=True) + return name, condition_image, expected_image, weights_path @pytest.fixture(scope="module") -def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]: +def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]: expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB") - weights_path = test_weights_path / "loras" / "pcuenq_pokemon_lora.safetensors" - return expected_image, weights_path + weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin" + + if not weights_path.is_file(): + warn(f"could not find weights at {weights_path}, skipping") + pytest.skip(allow_module_level=True) + + tensors = torch.load(weights_path) # type: ignore + return expected_image, tensors + + +@pytest.fixture(scope="module") +def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, dict[str, torch.Tensor]]: + expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB") + weights_path = test_weights_path / "loras" / "dpo-lora" / "pytorch_lora_weights.safetensors" + + if not weights_path.is_file(): + warn(f"could not find weights at {weights_path}, skipping") + pytest.skip(allow_module_level=True) + + tensors = load_from_safetensors(weights_path) + return expected_image, tensors @pytest.fixture @@ -1010,24 +1034,20 @@ def test_diffusion_controlnet_stack( @no_grad() def test_diffusion_lora( sd15_std: StableDiffusion_1, - lora_data_pokemon: tuple[Image.Image, Path], + lora_data_pokemon: tuple[Image.Image, dict[str, torch.Tensor]], test_device: torch.device, -): +) -> None: sd15 = sd15_std n_steps = 30 - expected_image, lora_weights_path = lora_data_pokemon - - if not lora_weights_path.is_file(): - warn(f"could not find weights at {lora_weights_path}, skipping") - pytest.skip(allow_module_level=True) + expected_image, lora_weights = lora_data_pokemon prompt = "a cute cat" clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sd15.set_num_inference_steps(n_steps) - SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, 
scale=1.0).inject() + SDLoraManager(sd15).load(lora_weights, scale=1) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) @@ -1045,77 +1065,45 @@ def test_diffusion_lora( @no_grad() -def test_diffusion_lora_float16( - sd15_std_float16: StableDiffusion_1, - lora_data_pokemon: tuple[Image.Image, Path], - test_device: torch.device, -): - sd15 = sd15_std_float16 - n_steps = 30 +def test_diffusion_sdxl_lora( + sdxl_ddim: StableDiffusion_XL, + lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]], +) -> None: + sdxl = sdxl_ddim + expected_image, lora_weights = lora_data_dpo - expected_image, lora_weights_path = lora_data_pokemon + # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA + # except that we are using DDIM instead of sde-dpmsolver++ + n_steps = 40 + seed = 12341234123 + guidance_scale = 7.5 + lora_scale = 1.4 + prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography" + negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white" - if not lora_weights_path.is_file(): - warn(f"could not find weights at {lora_weights_path}, skipping") - pytest.skip(allow_module_level=True) + SDLoraManager(sdxl).load(lora_weights, scale=lora_scale) - prompt = "a cute cat" - clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( + text=prompt, negative_text=negative_prompt + ) - sd15.set_num_inference_steps(n_steps) + time_ids = sdxl.default_time_ids + sdxl.set_num_inference_steps(n_steps) - SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject() + manual_seed(seed=seed) + x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) - manual_seed(2) - x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) - - for step in sd15.steps: - x = sd15( + for step in sdxl.steps: + x = sdxl( x, step=step, clip_text_embedding=clip_text_embedding, - condition_scale=7.5, + pooled_text_embedding=pooled_text_embedding, + time_ids=time_ids, + condition_scale=guidance_scale, ) - predicted_image = sd15.lda.decode_latents(x) - ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98) - - -@no_grad() -def test_diffusion_lora_twice( - sd15_std: StableDiffusion_1, - lora_data_pokemon: tuple[Image.Image, Path], - test_device: torch.device, -): - sd15 = sd15_std - n_steps = 30 - - expected_image, lora_weights_path = lora_data_pokemon - - if not lora_weights_path.is_file(): - warn(f"could not find weights at {lora_weights_path}, skipping") - pytest.skip(allow_module_level=True) - - prompt = "a cute cat" - clip_text_embedding = sd15.compute_clip_text_embedding(prompt) - - sd15.set_num_inference_steps(n_steps) - - # The same LoRA is used twice which is not a common use case: this is purely for testing purpose - SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject() - SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject() - - manual_seed(2) - x = torch.randn(1, 4, 64, 64, device=test_device) - - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + predicted_image = 
sdxl.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index 1a35ecf..0f572b9 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -48,6 +48,7 @@ Special cases: - `expected_restart.png` - `expected_freeu.png` - `expected_dropy_slime_9752.png` + - `expected_sdxl_dpo_lora.png` ## Other images diff --git a/tests/e2e/test_diffusion_ref/expected_sdxl_dpo_lora.png b/tests/e2e/test_diffusion_ref/expected_sdxl_dpo_lora.png new file mode 100644 index 0000000..5af8d4d Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_sdxl_dpo_lora.png differ
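Usage note: with this change there is no offline LoRA conversion step anymore; `SDLoraManager` adapts the target model in place when the weights are loaded. Below is a minimal sketch of the new flow, mirroring the updated e2e tests; the checkpoint path is a placeholder, and the base model weights are assumed to be loaded separately as usual:

    from refiners.fluxion.utils import load_from_safetensors
    from refiners.foundationals.latent_diffusion import StableDiffusion_1
    from refiners.foundationals.latent_diffusion.lora import SDLoraManager

    sd15 = StableDiffusion_1()  # base UNet / text encoder / LDA weights loaded as usual

    # raw LoRA weights: use torch.load for diffusers .bin files, safetensors otherwise
    tensors = load_from_safetensors("path/to/pytorch_lora_weights.safetensors")  # placeholder path

    manager = SDLoraManager(sd15)
    manager.load(tensors, scale=1.0)  # attaches LoraAdapter layers to the UNet and text encoder

    manager.scale = 0.7  # rescales every attached LoRA at once
    manager.unload()  # ejects all LoraAdapter layers, restoring the original model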